diff --git a/README.md b/README.md index 83ce0ac16..476aae6de 100644 --- a/README.md +++ b/README.md @@ -28,14 +28,15 @@ We provide the following recipes: - [yesno][yesno] - [LibriSpeech][librispeech] + - [GigaSpeech][gigaspeech] - [Aishell][aishell] + - [Aishell2][aishell2] + - [Aishell4][aishell4] - [TIMIT][timit] - [TED-LIUM3][tedlium3] - - [GigaSpeech][gigaspeech] - [Aidatatang_200zh][aidatatang_200zh] - [WenetSpeech][wenetspeech] - [Alimeeting][alimeeting] - - [Aishell4][aishell4] - [TAL_CSASR][tal_csasr] ### yesno @@ -46,9 +47,7 @@ Training takes less than 30 seconds and gives you the following WER: ``` [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] ``` -We do provide a Colab notebook for this recipe. - -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing) +We provide a Colab notebook for this recipe: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing) ### LibriSpeech @@ -82,7 +81,7 @@ The WER for this model is: |-----|------------|------------| | WER | 6.59 | 17.69 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) +We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) #### Transducer: Conformer encoder + LSTM decoder @@ -118,19 +117,54 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles | | test-clean | test-other | |-----|------------|------------| -| WER | 2.57 | 5.95 | +| WER | 2.15 | 5.20 | + +Note: No auxiliary losses are used in the training and no LMs are used +in the decoding. #### k2 pruned RNN-T + GigaSpeech | | test-clean | test-other | |-----|------------|------------| -| WER | 2.00 | 4.63 | +| WER | 1.78 | 4.08 | + +Note: No auxiliary losses are used in the training and no LMs are used +in the decoding. + +#### k2 pruned RNN-T + GigaSpeech + CommonVoice + +| | test-clean | test-other | +|-----|------------|------------| +| WER | 1.90 | 3.98 | + +Note: No auxiliary losses are used in the training and no LMs are used +in the decoding. + + +### GigaSpeech + +We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc] +and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. + +#### Conformer CTC + +| | Dev | Test | +|-----|-------|-------| +| WER | 10.47 | 10.58 | + +#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss + +| | Dev | Test | +|----------------------|-------|-------| +| greedy search | 10.51 | 10.73 | +| fast beam search | 10.50 | 10.69 | +| modified beam search | 10.40 | 10.51 | ### Aishell -We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc] -and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc]. 
+We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc], +[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7], #### Conformer CTC Model @@ -140,20 +174,6 @@ The best CER we currently have is: |-----|------| | CER | 4.26 | - -We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing) - -#### Transducer Stateless Model - -The best CER we currently have is: - -| | test | -|-----|------| -| CER | 4.68 | - - -We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) - #### TDNN LSTM CTC Model The CER for this model is: @@ -164,6 +184,46 @@ The CER for this model is: We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) +#### Transducer Stateless Model + +The best CER we currently have is: + +| | test | +|-----|------| +| CER | 4.38 | + +We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) + + +### Aishell2 + +We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5]. + +#### Transducer Stateless Model + +The best WER we currently have is: + +| | dev-ios | test-ios | +|-----|------------|------------| +| WER | 5.32 | 5.56 | + + +### Aishell4 + +We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5]. 
+ +#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets) + +The best CER we currently have is: + +| | test | +|-----|------------| +| CER | 29.08 | + + +We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) + + ### TIMIT We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc] @@ -187,7 +247,8 @@ The PER for this model is: |--|--| |PER| 17.66% | -We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11IT-k4HQIgQngXz1uvWsEYktjqQt7Tmb?usp=sharing) +We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) + ### TED-LIUM3 @@ -215,24 +276,6 @@ The best WER using modified beam search with beam size 4 is: We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) -### GigaSpeech - -We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc] -and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. - -#### Conformer CTC - -| | Dev | Test | -|-----|-------|-------| -| WER | 10.47 | 10.58 | - -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss - -| | Dev | Test | -|----------------------|-------|-------| -| greedy search | 10.51 | 10.73 | -| fast beam search | 10.50 | 10.69 | -| modified beam search | 10.40 | 10.51 | ### Aidatatang_200zh @@ -248,6 +291,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing) + ### WenetSpeech We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5]. @@ -284,20 +328,6 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) -### Aishell4 - -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5]. 
- -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets) - -The best CER(%) results: -| | test | -|----------------------|--------| -| greedy search | 29.89 | -| fast beam search | 28.91 | -| modified beam search | 29.08 | - -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) ### TAL_CSASR @@ -333,6 +363,9 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless [Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc [Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc +[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe +[Aishell2_pruned_transducer_stateless5]: egs/aishell2/ASR/pruned_transducer_stateless5 +[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5 [TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc [TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc [TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless @@ -343,17 +376,17 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5 [Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2 -[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5 [TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5 [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR [aishell]: egs/aishell/ASR +[aishell2]: egs/aishell2/ASR +[aishell4]: egs/aishell4/ASR [timit]: egs/timit/ASR [tedlium3]: egs/tedlium3/ASR [gigaspeech]: egs/gigaspeech/ASR [aidatatang_200zh]: egs/aidatatang_200zh/ASR [wenetspeech]: egs/wenetspeech/ASR [alimeeting]: egs/alimeeting/ASR -[aishell4]: egs/aishell4/ASR [tal_csasr]: egs/tal_csasr/ASR [k2]: https://github.com/k2-fsa/k2 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 5b9fb2664..738b24ab2 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -3,64 +3,91 @@ Installation ============ -- |os| -- |device| -- |python_versions| -- |torch_versions| -- |k2_versions| -.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg - :alt: Supported operating systems - -.. |device| image:: ./images/device-CPU_CUDA-orange.svg - :alt: Supported devices - -.. |python_versions| image:: ./images/python-gt-v3.6-blue.svg - :alt: Supported python versions - -.. |torch_versions| image:: ./images/torch-gt-v1.6.0-green.svg - :alt: Supported PyTorch versions - -.. |k2_versions| image:: ./images/k2-gt-v1.9-blueviolet.svg - :alt: Supported k2 versions ``icefall`` depends on `k2 `_ and `lhotse `_. -We recommend you to use the following steps to install the dependencies. +We recommend that you use the following steps to install the dependencies. -- (0) Install PyTorch and torchaudio -- (1) Install k2 -- (2) Install lhotse +- (0) Install CUDA toolkit and cuDNN +- (1) Install PyTorch and torchaudio +- (2) Install k2 +- (3) Install lhotse + +.. caution:: + + 99% users who have issues about the installation are using conda. + +.. 
caution:: + + 99% users who have issues about the installation are using conda. + +.. caution:: + + 99% users who have issues about the installation are using conda. + +.. hint:: + + We suggest that you use ``pip install`` to install PyTorch. + + You can use the following command to create a virutal environment in Python: + + .. code-block:: bash + + python3 -m venv ./my_env + source ./my_env/bin/activate .. caution:: Installation order matters. -(0) Install PyTorch and torchaudio +(0) Install CUDA toolkit and cuDNN +---------------------------------- + +Please refer to +``_ +to install CUDA and cuDNN. + + +(1) Install PyTorch and torchaudio ---------------------------------- Please refer ``_ to install PyTorch and torchaudio. +.. hint:: -(1) Install k2 + You can also go to ``_ + to download pre-compiled wheels and install them. + +.. caution:: + + Please install torch and torchaudio at the same time. + + +(2) Install k2 -------------- Please refer to ``_ to install ``k2``. -.. CAUTION:: +.. caution:: - You need to install ``k2`` with a version at least **v1.9**. + Please don't change your installed PyTorch after you have installed k2. -.. HINT:: +.. note:: - If you have already installed PyTorch and don't want to replace it, - please install a version of ``k2`` that is compiled against the version - of PyTorch you are using. + We suggest that you install k2 from source by following + ``_ + or + ``_. -(2) Install lhotse +.. hint:: + + Please always install the latest version of k2. + +(3) Install lhotse ------------------ Please refer to ``_ @@ -75,8 +102,7 @@ to install ``lhotse``. to install the latest version of lhotse. - -(3) Download icefall +(4) Download icefall -------------------- ``icefall`` is a collection of Python scripts; what you need is to download it @@ -338,44 +364,42 @@ The log of running ``./prepare.sh`` is: .. code-block:: - 2021-08-23 19:27:26 (prepare.sh:24:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download - 2021-08-23 19:27:26 (prepare.sh:27:main) stage 0: Download data - Downloading waves_yesno.tar.gz: 4.49MB [00:03, 1.39MB/s] - 2021-08-23 19:27:30 (prepare.sh:36:main) Stage 1: Prepare yesno manifest - 2021-08-23 19:27:31 (prepare.sh:42:main) Stage 2: Compute fbank for yesno - 2021-08-23 19:27:32,803 INFO [compute_fbank_yesno.py:52] Processing train - Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:01<00:00, 80.57it/s] - 2021-08-23 19:27:34,085 INFO [compute_fbank_yesno.py:52] Processing test - Extracting and storing features: 100%|______________________________________________________________| 30/30 [00:00<00:00, 248.21it/s] - 2021-08-23 19:27:34 (prepare.sh:48:main) Stage 3: Prepare lang - 2021-08-23 19:27:35 (prepare.sh:63:main) Stage 4: Prepare G - /tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea - d(std::istream&):79 - [I] Reading \data\ section. - /tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea - d(std::istream&):140 - [I] Reading \1-grams: section. - 2021-08-23 19:27:35 (prepare.sh:89:main) Stage 5: Compile HLG - 2021-08-23 19:27:35,928 INFO [compile_hlg.py:120] Processing data/lang_phone - 2021-08-23 19:27:35,929 INFO [lexicon.py:116] Converting L.pt to Linv.pt - 2021-08-23 19:27:35,931 INFO [compile_hlg.py:48] Building ctc_topo. 
max_token_id: 3 - 2021-08-23 19:27:35,932 INFO [compile_hlg.py:52] Loading G.fst.txt - 2021-08-23 19:27:35,932 INFO [compile_hlg.py:62] Intersecting L and G - 2021-08-23 19:27:35,933 INFO [compile_hlg.py:64] LG shape: (4, None) - 2021-08-23 19:27:35,933 INFO [compile_hlg.py:66] Connecting LG - 2021-08-23 19:27:35,933 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None) - 2021-08-23 19:27:35,933 INFO [compile_hlg.py:70] - 2021-08-23 19:27:35,933 INFO [compile_hlg.py:71] Determinizing LG - 2021-08-23 19:27:35,934 INFO [compile_hlg.py:74] - 2021-08-23 19:27:35,934 INFO [compile_hlg.py:76] Connecting LG after k2.determinize - 2021-08-23 19:27:35,934 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG - 2021-08-23 19:27:35,934 INFO [compile_hlg.py:87] LG shape after k2.remove_epsilon: (6, None) - 2021-08-23 19:27:35,935 INFO [compile_hlg.py:92] Arc sorting LG - 2021-08-23 19:27:35,935 INFO [compile_hlg.py:95] Composing H and LG - 2021-08-23 19:27:35,935 INFO [compile_hlg.py:102] Connecting LG - 2021-08-23 19:27:35,935 INFO [compile_hlg.py:105] Arc sorting LG - 2021-08-23 19:27:35,936 INFO [compile_hlg.py:107] HLG.shape: (8, None) - 2021-08-23 19:27:35,936 INFO [compile_hlg.py:123] Saving HLG.pt to data/lang_phone + 2023-05-12 17:55:21 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download + 2023-05-12 17:55:21 (prepare.sh:30:main) Stage 0: Download data + /tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|_______________________________________________________________| 4.70M/4.70M [06:54<00:00, 11.4kB/s] + 2023-05-12 18:02:19 (prepare.sh:39:main) Stage 1: Prepare yesno manifest + 2023-05-12 18:02:21 (prepare.sh:45:main) Stage 2: Compute fbank for yesno + 2023-05-12 18:02:23,199 INFO [compute_fbank_yesno.py:65] Processing train + Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:00<00:00, 212.60it/s] + 2023-05-12 18:02:23,640 INFO [compute_fbank_yesno.py:65] Processing test + Extracting and storing features: 100%|_______________________________________________________________| 30/30 [00:00<00:00, 304.53it/s] + 2023-05-12 18:02:24 (prepare.sh:51:main) Stage 3: Prepare lang + 2023-05-12 18:02:26 (prepare.sh:66:main) Stage 4: Prepare G + /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79 + [I] Reading \data\ section. + /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140 + [I] Reading \1-grams: section. + 2023-05-12 18:02:26 (prepare.sh:92:main) Stage 5: Compile HLG + 2023-05-12 18:02:28,581 INFO [compile_hlg.py:124] Processing data/lang_phone + 2023-05-12 18:02:28,582 INFO [lexicon.py:171] Converting L.pt to Linv.pt + 2023-05-12 18:02:28,609 INFO [compile_hlg.py:48] Building ctc_topo. 
max_token_id: 3 + 2023-05-12 18:02:28,610 INFO [compile_hlg.py:52] Loading G.fst.txt + 2023-05-12 18:02:28,611 INFO [compile_hlg.py:62] Intersecting L and G + 2023-05-12 18:02:28,613 INFO [compile_hlg.py:64] LG shape: (4, None) + 2023-05-12 18:02:28,613 INFO [compile_hlg.py:66] Connecting LG + 2023-05-12 18:02:28,614 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None) + 2023-05-12 18:02:28,614 INFO [compile_hlg.py:70] + 2023-05-12 18:02:28,614 INFO [compile_hlg.py:71] Determinizing LG + 2023-05-12 18:02:28,615 INFO [compile_hlg.py:74] + 2023-05-12 18:02:28,615 INFO [compile_hlg.py:76] Connecting LG after k2.determinize + 2023-05-12 18:02:28,615 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG + 2023-05-12 18:02:28,616 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None) + 2023-05-12 18:02:28,617 INFO [compile_hlg.py:96] Arc sorting LG + 2023-05-12 18:02:28,617 INFO [compile_hlg.py:99] Composing H and LG + 2023-05-12 18:02:28,619 INFO [compile_hlg.py:106] Connecting LG + 2023-05-12 18:02:28,619 INFO [compile_hlg.py:109] Arc sorting LG + 2023-05-12 18:02:28,619 INFO [compile_hlg.py:111] HLG.shape: (8, None) + 2023-05-12 18:02:28,619 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone Training @@ -408,49 +432,53 @@ The training log is given below: .. code-block:: - 2021-08-23 19:30:31,072 INFO [train.py:465] Training started - 2021-08-23 19:30:31,072 INFO [train.py:466] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, - 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, ' - best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_doub - le_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'feature_dir': PosixPath('data/fbank' - ), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0 - , 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2} - 2021-08-23 19:30:31,074 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt - 2021-08-23 19:30:31,098 INFO [asr_datamodule.py:146] About to get train cuts - 2021-08-23 19:30:31,098 INFO [asr_datamodule.py:240] About to get train cuts - 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:149] About to create train dataset - 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:200] Using SingleCutSampler. 
- 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:206] About to create train dataloader - 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:219] About to get test cuts - 2021-08-23 19:30:31,102 INFO [asr_datamodule.py:246] About to get test cuts - 2021-08-23 19:30:31,357 INFO [train.py:416] Epoch 0, batch 0, batch avg loss 1.0789, total avg loss: 1.0789, batch size: 4 - 2021-08-23 19:30:31,848 INFO [train.py:416] Epoch 0, batch 10, batch avg loss 0.5356, total avg loss: 0.7556, batch size: 4 - 2021-08-23 19:30:32,301 INFO [train.py:432] Epoch 0, valid loss 0.9972, best valid loss: 0.9972 best valid epoch: 0 - 2021-08-23 19:30:32,805 INFO [train.py:416] Epoch 0, batch 20, batch avg loss 0.2436, total avg loss: 0.5717, batch size: 3 - 2021-08-23 19:30:33,109 INFO [train.py:432] Epoch 0, valid loss 0.4167, best valid loss: 0.4167 best valid epoch: 0 - 2021-08-23 19:30:33,121 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-0.pt - 2021-08-23 19:30:33,325 INFO [train.py:416] Epoch 1, batch 0, batch avg loss 0.2214, total avg loss: 0.2214, batch size: 5 - 2021-08-23 19:30:33,798 INFO [train.py:416] Epoch 1, batch 10, batch avg loss 0.0781, total avg loss: 0.1343, batch size: 5 - 2021-08-23 19:30:34,065 INFO [train.py:432] Epoch 1, valid loss 0.0859, best valid loss: 0.0859 best valid epoch: 1 - 2021-08-23 19:30:34,556 INFO [train.py:416] Epoch 1, batch 20, batch avg loss 0.0421, total avg loss: 0.0975, batch size: 3 - 2021-08-23 19:30:34,810 INFO [train.py:432] Epoch 1, valid loss 0.0431, best valid loss: 0.0431 best valid epoch: 1 - 2021-08-23 19:30:34,824 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-1.pt + 2023-05-12 18:04:59,759 INFO [train.py:481] Training started + 2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, + 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, + 'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, + 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, + 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023', + 'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master', + 'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall', + 'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py', + 'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}} + 2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu + 2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts + 2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] 
About to get train cuts + 2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset + 2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler. + 2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader + 2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts + 2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts + 2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4 + 2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4 + 2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames. + 2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5 + 2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames. + 2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt + 2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4 + 2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4 + 2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames. + 2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5 + 2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames. + 2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt ... ... - 2021-08-23 19:30:49,657 INFO [train.py:416] Epoch 13, batch 0, batch avg loss 0.0109, total avg loss: 0.0109, batch size: 5 - 2021-08-23 19:30:49,984 INFO [train.py:416] Epoch 13, batch 10, batch avg loss 0.0093, total avg loss: 0.0096, batch size: 4 - 2021-08-23 19:30:50,239 INFO [train.py:432] Epoch 13, valid loss 0.0104, best valid loss: 0.0101 best valid epoch: 12 - 2021-08-23 19:30:50,569 INFO [train.py:416] Epoch 13, batch 20, batch avg loss 0.0092, total avg loss: 0.0096, batch size: 2 - 2021-08-23 19:30:50,819 INFO [train.py:432] Epoch 13, valid loss 0.0101, best valid loss: 0.0101 best valid epoch: 13 - 2021-08-23 19:30:50,835 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-13.pt - 2021-08-23 19:30:51,024 INFO [train.py:416] Epoch 14, batch 0, batch avg loss 0.0105, total avg loss: 0.0105, batch size: 5 - 2021-08-23 19:30:51,317 INFO [train.py:416] Epoch 14, batch 10, batch avg loss 0.0099, total avg loss: 0.0097, batch size: 4 - 2021-08-23 19:30:51,552 INFO [train.py:432] Epoch 14, valid loss 0.0108, best valid loss: 0.0101 best valid epoch: 13 - 2021-08-23 19:30:51,869 INFO [train.py:416] Epoch 14, batch 20, batch avg loss 0.0096, total avg loss: 0.0097, batch size: 5 - 2021-08-23 19:30:52,107 INFO [train.py:432] Epoch 14, valid loss 0.0102, best valid loss: 0.0101 best valid epoch: 13 - 2021-08-23 19:30:52,126 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-14.pt - 2021-08-23 19:30:52,128 INFO [train.py:537] Done! 
+ 2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4 + 2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4 + 2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames. + 2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5 + 2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames. + 2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt + 2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4 + 2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4 + 2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames. + 2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5 + 2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames. + 2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt + 2023-05-12 18:05:16,865 INFO [train.py:555] Done! Decoding ~~~~~~~~ @@ -465,22 +493,25 @@ The decoding log is: .. code-block:: - 2021-08-23 19:35:30,192 INFO [decode.py:249] Decoding started - 2021-08-23 19:35:30,192 INFO [decode.py:250] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2} - 2021-08-23 19:35:30,193 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt - 2021-08-23 19:35:30,213 INFO [decode.py:259] device: cpu - 2021-08-23 19:35:30,217 INFO [decode.py:279] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] - /tmp/icefall/icefall/checkpoint.py:146: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. - It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. - To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:450.) 
- avg[k] //= n - 2021-08-23 19:35:30,220 INFO [asr_datamodule.py:219] About to get test cuts - 2021-08-23 19:35:30,220 INFO [asr_datamodule.py:246] About to get test cuts - 2021-08-23 19:35:30,409 INFO [decode.py:190] batch 0/8, cuts processed until now is 4 - 2021-08-23 19:35:30,571 INFO [decode.py:228] The transcripts are stored in tdnn/exp/recogs-test_set.txt - 2021-08-23 19:35:30,572 INFO [utils.py:317] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] - 2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt - 2021-08-23 19:35:30,573 INFO [decode.py:299] Done! + 2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started + 2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, + 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), + 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, + 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023', + 'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master', + 'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall', + 'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py', + 'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}} + 2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu + 2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts + 2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts + 2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4 + 2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt + 2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] + 2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt + 2023-05-12 18:08:30,925 INFO [decode.py:316] Done! **Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``. diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index aa77204cb..fb952abb7 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -3,6 +3,15 @@ Export to ONNX In this section, we describe how to export models to `ONNX`_. +.. hint:: + + Before you continue, please run: + + .. code-block:: bash + + pip install onnx + + In each recipe, there is a file called ``export-onnx.py``, which is used to export trained models to `ONNX`_. 
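As a quick, optional sanity check after running ``export-onnx.py``, you can
open the exported file with ``onnxruntime`` (installed separately via
``pip install onnxruntime``) and list its input and output tensors. The
snippet below is only an illustrative sketch, not part of the recipes;
``encoder.onnx`` is a hypothetical placeholder for whatever file your
recipe's ``export-onnx.py`` actually writes.

.. code-block:: python

    # Illustrative sketch only: verify an exported ONNX file and inspect its I/O.
    # "encoder.onnx" is a hypothetical placeholder; use the file produced by
    # your recipe's export-onnx.py.
    import onnx
    import onnxruntime as ort

    model_path = "encoder.onnx"

    # Structural check of the exported graph.
    onnx.checker.check_model(onnx.load(model_path))

    # Create a CPU inference session and print the expected inputs and outputs.
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    for t in session.get_inputs():
        print("input :", t.name, t.shape, t.type)
    for t in session.get_outputs():
        print("output:", t.name, t.shape, t.type)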
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 4c730c4ae..aa18502c2 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,6 +2,57 @@ ### Aishell training result(Stateless Transducer) +#### Pruned transducer stateless 7 (zipformer) + +See + +[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe) + +**Note**: The modeling units are byte level BPEs + +The best results I have gotten are: + +Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments +-- | -- | -- | -- | -- | -- +500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29 + +The training command: + +``` +export CUDA_VISIBLE_DEVICES="4,5,6,7" + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --max-duration 800 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --lr-epochs 6 \ + --master-port 12535 +``` + +The decoding command: + +``` +for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do + ./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 48 \ + --avg 29 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-sym-per-frame 1 \ + --ngram-lm-scale 0.25 \ + --ilme-scale 0.2 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 2000 \ + --decoding-method $m +done +``` + +The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + + #### Pruned transducer stateless 3 See @@ -75,7 +126,7 @@ for epoch in 29; do done ``` -We provide the option of shallow fusion with a RNN language model. The pre-trained language model is +We provide the option of shallow fusion with a RNN language model. The pre-trained language model is available at . To decode with the language model, please use the following command: diff --git a/egs/aishell/ASR/local/compile_lg.py b/egs/aishell/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/aishell/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index 6b440dfb3..8cc0502c2 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ b/egs/aishell/ASR/local/prepare_char.py @@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`: - tokens.txt """ +import argparse import re from pathlib import Path from typing import Dict, List @@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]: return tokens +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + return parser.parse_args() + + def main(): - lang_dir = Path("data/lang_char") + args = get_args() + lang_dir = Path(args.lang_dir) text_file = lang_dir / "text" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") diff --git a/egs/aishell/ASR/local/prepare_lang_bbpe.py b/egs/aishell/ASR/local/prepare_lang_bbpe.py new file mode 100755 index 000000000..ddd90622e --- /dev/null +++ b/egs/aishell/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang
+#                                                       Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+    - lang_dir/bbpe.model,
+    - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+    - lexicon.txt
+    - lexicon_disambig.txt
+    - L.pt
+    - L_disambig.pt
+    - tokens.txt
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+from prepare_lang import (
+    Lexicon,
+    add_disambig_symbols,
+    add_self_loops,
+    write_lexicon,
+    write_mapping,
+)
+
+from icefall.byte_utils import byte_encode
+from icefall.utils import str2bool, tokenize_by_CJK_char
+
+
+def lexicon_to_fst_no_sil(
+    lexicon: Lexicon,
+    token2id: Dict[str, int],
+    word2id: Dict[str, int],
+    need_self_loops: bool = False,
+) -> k2.Fsa:
+    """Convert a lexicon to an FST (in k2 format).
+
+    Args:
+      lexicon:
+        The input lexicon. See also :func:`read_lexicon`
+      token2id:
+        A dict mapping tokens to IDs.
+      word2id:
+        A dict mapping words to IDs.
+      need_self_loops:
+        If True, add self-loop to states with non-epsilon output symbols
+        on at least one arc out of the state. The input label for this
+        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+    Returns:
+      Return an instance of `k2.Fsa` representing the given lexicon.
+    """
+    loop_state = 0  # words enter and leave from here
+    next_state = 1  # the next un-allocated state, will be incremented as we go
+
+    arcs = []
+
+    # The blank symbol <blk> is defined in local/train_bpe_model.py
+    assert token2id["<blk>"] == 0
+    assert word2id["<eps>"] == 0
+
+    eps = 0
+
+    for word, pieces in lexicon:
+        assert len(pieces) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        pieces = [token2id[i] for i in pieces]
+
+        for i in range(len(pieces) - 1):
+            w = word if i == 0 else eps
+            arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last piece of this word
+        i = len(pieces) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+    if need_self_loops:
+        disambig_token = token2id["#0"]
+        disambig_word = word2id["#0"]
+        arcs = add_self_loops(
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
+        )
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def generate_lexicon(
+    model_file: str, words: List[str], oov: str
+) -> Tuple[Lexicon, Dict[str, int]]:
+    """Generate a lexicon from a BPE model.
+
+    Args:
+      model_file:
+        Path to a sentencepiece model.
+      words:
+        A list of strings representing words.
+      oov:
+        The out of vocabulary word in lexicon.
+    Returns:
+      Return a tuple with two elements:
+        - A lexicon, i.e., a list of (word, word pieces) pairs.
+        - A dict representing the token symbol, mapping from tokens to IDs.
+    """
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(model_file))
+
+    # Convert word to word piece IDs instead of word piece strings
+    # to avoid OOV tokens.
+    encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
+    words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)
+
+    # Now convert word piece IDs back to word piece strings.
+    words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
+
+    lexicon = []
+    for word, pieces in zip(words, words_pieces):
+        lexicon.append((word, pieces))
+
+    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
+
+    token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
+
+    return lexicon, token2id
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain the bbpe.model and words.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--oov",
+        type=str,
+        default="<UNK>",
+        help="The out of vocabulary word in lexicon.",
+    )
+
+    parser.add_argument(
+        "--debug",
+        type=str2bool,
+        default=False,
+        help="""True for debugging, which will generate
+        a visualization of the lexicon FST.
+
+        Caution: If your lexicon contains hundreds of thousands
+        of lines, please set it to False!
+
+        See "test/test_bpe_lexicon.py" for usage.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+    model_file = lang_dir / "bbpe.model"
+
+    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+    words = word_sym_table.symbols
+
+    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
+
+    for w in excluded:
+        if w in words:
+            words.remove(w)
+
+    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
+
+    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+    next_token_id = max(token_sym_table.values()) + 1
+    for i in range(max_disambig + 1):
+        disambig = f"#{i}"
+        assert disambig not in token_sym_table
+        token_sym_table[disambig] = next_token_id
+        next_token_id += 1
+
+    word_sym_table.add("#0")
+    word_sym_table.add("<s>")
+    word_sym_table.add("</s>")
+
+    write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+    write_lexicon(lang_dir / "lexicon.txt", lexicon)
+    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+    L = lexicon_to_fst_no_sil(
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
+    )
+
+    L_disambig = lexicon_to_fst_no_sil(
+        lexicon_disambig,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
+        need_self_loops=True,
+    )
+    torch.save(L.as_dict(), lang_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+    if args.debug:
+        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+        L.labels_sym = labels_sym
+        L.aux_labels_sym = aux_labels_sym
+        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+        L_disambig.labels_sym = labels_sym
+        L_disambig.aux_labels_sym = aux_labels_sym
+        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py
new file mode 100755
index 000000000..d231d5d77
--- /dev/null
+++ b/egs/aishell/ASR/local/train_bbpe_model.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright    2023  Xiaomi Corp.        (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+#  pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import re
+import shutil
+import tempfile
+from pathlib import Path
+
+import sentencepiece as spm
+from icefall import byte_encode, tokenize_by_CJK_char
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    return parser.parse_args()
+
+
+def _convert_to_bchar(in_path: str, out_path: str):
+    with open(out_path, "w") as f:
+        for line in open(in_path, "r").readlines():
+            f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")
+
+
+def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+
+    model_type = "unigram"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    character_coverage = 1.0
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
+
+    temp = tempfile.NamedTemporaryFile()
+    train_text = temp.name
+
+    _convert_to_bchar(args.transcript, train_text)
+
+    model_file = Path(model_prefix + ".model")
+    if not model_file.is_file():
+        spm.SentencePieceTrainer.train(
+            input=train_text,
+            vocab_size=vocab_size,
+            model_type=model_type,
+            model_prefix=model_prefix,
+            input_sentence_size=input_sentence_size,
+            character_coverage=character_coverage,
+            user_defined_symbols=user_defined_symbols,
+            unk_id=unk_id,
+            bos_id=-1,
+            eos_id=-1,
+        )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
+    shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 3e0d5f51b..b763d72c1 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -35,6 +35,15 @@ dl_dir=$PWD/download
 . shared/parse_options.sh || exit 1
 
+# vocab size for sentence piece models.
+# It will generate data/lang_bbpe_xxx,
+# data/lang_bbpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  # 2000
+  # 1000
+  500
+)
+
 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
@@ -47,20 +56,6 @@ log() {
 
 log "dl_dir: $dl_dir"
 
-if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "stage -1: Download LM"
-  # We assume that you have installed the git-lfs, if not, you could install it
-  # using: `sudo apt-get install git-lfs && git-lfs install`
-  git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
-
-  if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
-    git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
-    pushd $dl_dir/lm
-    git lfs pull --include "3-gram.unpruned.arpa"
-    popd
-  fi
-fi
-
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "stage 0: Download data"
 
@@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 lang_phone_dir=data/lang_phone
-lang_char_dir=data/lang_char
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
   mkdir -p $lang_phone_dir
@@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi
 fi
 
+lang_char_dir=data/lang_char
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare char based lang"
   mkdir -p $lang_char_dir
   # We reuse words.txt from phone based lexicon
   # so that the two can share G.pt later.
-  cp $lang_phone_dir/words.txt $lang_char_dir
+
+  # The transcripts in training set, generated in stage 5
+  cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
 
   cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
-    cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
+    cut -d " " -f 2- > $lang_char_dir/text
+
+  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+    > $lang_char_dir/words.txt
+
+  cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+    | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
+
+  num_lines=$(< $lang_char_dir/words.txt wc -l)
+  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
+    >> $lang_char_dir/words.txt
 
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
-    ./local/prepare_char.py
+    ./local/prepare_char.py --lang-dir $lang_char_dir
   fi
 fi
 
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-  log "Stage 7: Prepare G"
-  # We assume you have install kaldilm, if not, please install
-  # it using: pip install kaldilm
+  log "Stage 7: Prepare Byte BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    mkdir -p $lang_dir
+
+    cp $lang_char_dir/words.txt $lang_dir
+    cp $lang_char_dir/text $lang_dir
+
+    if [ ! -f $lang_dir/bbpe.model ]; then
+      ./local/train_bbpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/text
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
 
   mkdir -p data/lm
-  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+
+  # Train LM on transcripts
+  if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+    python3 ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/transcript_words.txt \
+      -lm data/lm/3-gram.unpruned.arpa
+  fi
+
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+  if [ ! 
-f data/lm/G_3_gram_char.fst.txt ]; then # It is used in building HLG python3 -m kaldilm \ --read-symbol-table="$lang_phone_dir/words.txt" \ --disambig-symbol='#0' \ --max-order=3 \ - $dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt + + python3 -m kaldilm \ + --read-symbol-table="$lang_char_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt fi fi -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compile HLG" - ./local/compile_hlg.py --lang-dir $lang_phone_dir - ./local/compile_hlg.py --lang-dir $lang_char_dir +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile LG & HLG" + ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone + ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char + done + + ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone + ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char + done fi -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Generate LM training data" +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Generate LM training data" log "Processing char based data" out_dir=data/lm_training_char @@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then fi -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Sort LM training data" +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Sort LM training data" # Sort LM training data by sentence length in descending order # for ease of training. 
# @@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then --out-statistics $out_dir/statistics-test.txt fi -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Stage 11: Train RNN LM model" python ../../../icefall/rnn_lm/train.py \ --start-epoch 0 \ diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py new file mode 100755 index 000000000..fcb0ebc4e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py @@ -0,0 +1,819 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import ( + LmScorer, + NgramLm, + byte_encode, + smart_byte_decode, + tokenize_by_CJK_char, +) +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the byte BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.25, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
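+
+    For instance, with greedy_search the returned dict could look like the
+    following (the transcripts are made up and only illustrate the structure):
+
+        {"greedy_search": [["今", "天", "天", "气"], ["很", "好"]]}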
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + subtract_ilme=True, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + ref_texts = [] + for tx in supervisions["text"]: + ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) + + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(ref_texts), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = 
f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + key += f"_ilme_scale_{params.ilme_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"-ilme-scale-{params.ilme_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + 
model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
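+    # Setting return_cuts=True makes each batch carry its lhotse Cut objects
+    # in batch["supervisions"]["cut"], which is where decode_dataset() reads
+    # the cut ids from.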
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py new file mode 100755 index 000000000..4e82b45d3 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
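+
+For example, a minimal, abbreviated sketch of loading it in Python (building
+`model` first requires the same parser defaults and bbpe vocab handling as in
+./pretrained.py):
+
+    from icefall.checkpoint import load_checkpoint
+    load_checkpoint("pruned_transducer_stateless7_bbpe/exp/pretrained.pt", model=model)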
+ +To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/aishell/ASR + ./pruned_transducer_stateless7_bbpe/decode.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_500/bbpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + # You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_bbpe/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
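+        # torch.jit.ignore() keeps forward() as a plain Python method and
+        # excludes it from compilation; the scripted module still exposes the
+        # encoder, decoder and joiner, which is what downstream tools use.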
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py new file mode 100755 index 000000000..0c43bf74b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 49 \ + --avg 28 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_bbpe/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. 
+ Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in 
waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = smart_byte_decode(sp.decode(hyp)) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py new file mode 100755 index 000000000..ea5bda4db --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
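+Unlike ./jit_pretrained.py, which loads a torchscript model, this script builds
+the model in Python with get_transducer_model() and loads the state_dict saved
+in pretrained.pt.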
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 48 \ + --avg 29 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_bbpe/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import smart_byte_decode +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py new file mode 120000 index 000000000..7ceac5d10 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py new file mode 100755 index 000000000..499badb14 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -0,0 +1,1261 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --max-duration 400 + +# For mix precision training: + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --max-duration 800 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma 
separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_bbpe/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the Byte BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 2000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
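+ sp:
+ The BPE model. It is only used to encode the transcripts into
+ token IDs.
+ Returns:
+ Return a tuple of two elements: the loss (a sum over the utterances
+ in the batch, since reduction="sum" is used) and a MetricsTracker
+ that holds the number of frames together with the total, simple and
+ pruned loss values for this batch.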
+ """
+ # For the uneven-sized batch, the total duration after padding would possibly
+ # cause OOM. Hence, for each batch, which is sorted in descending order of
+ # length, we simply drop the last few shortest samples, so that the retained
+ # total frames (after padding) would not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
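+ # The values below are therefore sums over the utterances in the batch;
+ # they are only normalized by the number of frames when the statistics
+ # are reported, e.g. tot_loss["loss"] / tot_loss["frames"] in
+ # compute_validation_loss() and train_one_epoch().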
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 12.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_cuts = valid_cuts.map(tokenize_and_encode_text) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index fc28e8dbc..efb32336a 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,7 +21,7 @@ import inspect import logging from functools import lru_cache from pathlib import Path -from typing import List +from typing import Any, Dict, List, Optional from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( @@ -181,7 +181,16 @@ class AishellAsrDataModule: "with training dataset. ", ) - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
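+ When it is not None, it is used to restore the state of the
+ training sampler so that training can resume from a checkpoint
+ that was saved in the middle of an epoch.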
+ """ logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") @@ -277,6 +286,10 @@ class AishellAsrDataModule: ) logging.info("About to create train dataloader") + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + train_dl = DataLoader( train, sampler=train_sampler, @@ -325,7 +338,7 @@ class AishellAsrDataModule: return valid_dl def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") + logging.info("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 5a956fc9c..2ca0558ab 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,76 @@ ## Results +### pruned_transducer_stateless7 (zipformer + multidataset(LibriSpeech + GigaSpeech + CommonVoice 13.0)) + +See for more details. + +[pruned_transducer_stateless7](./pruned_transducer_stateless7) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. + +Number of model parameters: 70369391, i.e., 70.37 M + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 1.91 | 4.06 | --epoch 30 --avg 7 | +| modified_beam_search | 1.90 | 3.99 | --epoch 30 --avg 7 | +| fast_beam_search | 1.90 | 3.98 | --epoch 30 --avg 7 | + + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless7/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --use-multidataset 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7/exp +``` + +The decoding commands are: +```bash +# greedy_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +# modified_beam_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast_beam_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +``` + ### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset) #### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi) @@ -215,11 +286,12 @@ done We also support decoding with neural network LMs. 
After combining with language models, the WERs are | decoding method | chunk size | test-clean | test-other | comment | decoding mode | |----------------------|------------|------------|------------|---------------------|----------------------| -| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | -| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming | -| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming | -Please use the following command for RNNLM shallow fusion: +Please use the following command for `modified_beam_search_lm_shallow_fusion`: ```bash for lm_scale in $(seq 0.15 0.01 0.38); do for beam_size in 4 8 12; do @@ -246,7 +318,7 @@ for lm_scale in $(seq 0.15 0.01 0.38); do done ``` -Please use the following command for RNNLM rescore: +Please use the following command for `modified_beam_search_lm_rescore`: ```bash ./pruned_transducer_stateless7_streaming/decode.py \ --epoch 30 \ @@ -268,7 +340,32 @@ Please use the following command for RNNLM rescore: --lm-vocab-size 500 ``` -A well-trained RNNLM can be found here: . +Please use the following command for `modified_beam_search_lm_rescore_LODR`: +```bash +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --use-averaged-model True \ + --beam-size 8 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore_LODR \ + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 \ + --tokens-ngram 2 \ + --backoff-id 500 +``` + +A well-trained RNNLM can be found here: . The bi-gram used in LODR decoding +can be found here: . #### Smaller model diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 08dac6a7b..d19d50ae6 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -109,10 +109,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 19bf3bff4..709b14070 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -45,11 +45,18 @@ def get_args(): help="""Input and output directory. 
""", ) + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) return parser.parse_args() -def compile_LG(lang_dir: str) -> k2.Fsa: +def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa: lexicon = Lexicon(lang_dir) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load("data/lm/G_3_gram.pt") + if Path(f"data/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"data/lm/{lm}.pt") G = k2.Fsa.from_dict(d) else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: + logging.info(f"Loading {lm}.fst.txt") + with open(f"data/lm/{lm}.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + torch.save(G.as_dict(), f"data/lm/{lm}.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] @@ -96,10 +103,11 @@ def compile_LG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 @@ -126,7 +134,7 @@ def main(): logging.info(f"Processing {lang_dir}") - LG = compile_LG(lang_dir) + LG = compile_LG(lang_dir, args.lm) logging.info(f"Saving LG.pt to {lang_dir}") torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 3518db524..da1648d06 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -41,9 +41,57 @@ import os import shutil from pathlib import Path -from lhotse.utils import urlretrieve_progress from tqdm.auto import tqdm +# This function is copied from lhotse +def tqdm_urlretrieve_hook(t): + """Wraps tqdm instance. + Don't forget to close() or __exit__() + the tqdm instance once you're done with it (easiest using `with` syntax). + Example + ------- + >>> from urllib.request import urlretrieve + >>> with tqdm(...) as t: + ... reporthook = tqdm_urlretrieve_hook(t) + ... urlretrieve(..., reporthook=reporthook) + + Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + """ + last_b = [0] + + def update_to(b=1, bsize=1, tsize=None): + """ + b : int, optional + Number of blocks transferred so far [default: 1]. + bsize : int, optional + Size of each block (in tqdm units) [default: 1]. + tsize : int, optional + Total size (in tqdm units). If [default: None] or -1, + remains unchanged. + """ + if tsize not in (None, -1): + t.total = tsize + displayed = t.update((b - last_b[0]) * bsize) + last_b[0] = b + return displayed + + return update_to + + +# This function is copied from lhotse +def urlretrieve_progress(url, filename=None, data=None, desc=None): + """ + Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to + display a progress bar of the download. 
+ Use "desc" argument to display a user-readable string that informs what is + being downloaded. + """ + from urllib.request import urlretrieve + + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t: + reporthook = tqdm_urlretrieve_hook(t) + return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data) + def get_args(): parser = argparse.ArgumentParser() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index b1d207049..8342d5212 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./shared/convert-k2-to-openfst.py \ --olabels aux_labels \ $lang_dir/L_disambig.pt \ - $lang_dir/disambig_L.fst + $lang_dir/L_disambig.fst fi fi diff --git a/egs/librispeech/ASR/prepare_multidataset.sh b/egs/librispeech/ASR/prepare_multidataset.sh index c068305c0..c95b4d039 100755 --- a/egs/librispeech/ASR/prepare_multidataset.sh +++ b/egs/librispeech/ASR/prepare_multidataset.sh @@ -198,7 +198,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./shared/convert-k2-to-openfst.py \ --olabels aux_labels \ $lang_dir/L_disambig.pt \ - $lang_dir/disambig_L.fst + $lang_dir/L_disambig.fst fi fi @@ -328,46 +328,3 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then ./prepare_common_voice.sh fi fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Create multidataset" - split_dir=data/fbank/multidataset_split_${num_splits} - if [ ! -f data/fbank/multidataset_split/.multidataset.done ]; then - mkdir -p $split_dir/multidataset - log "Split LibriSpeech" - if [ ! -f $split_dir/.librispeech_split.done ]; then - lhotse split $num_splits ./data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz $split_dir - touch $split_dir/.librispeech_split.done - fi - - if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then - log "Split GigaSpeech XL" - if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then - cd $split_dir - ln -sv ../gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz . - cd ../../.. - touch $split_dir/.gigaspeech_XL_split.done - fi - fi - - if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then - log "Split CommonVoice" - if [ ! -f $split_dir/.cv-en_train_split.done ]; then - lhotse split $num_splits ./data/en/fbank/cv-en_cuts_train.jsonl.gz $split_dir - touch $split_dir/.cv-en_train_split.done - fi - fi - - if [ ! -f $split_dir/.multidataset_mix.done ]; then - log "Mix multidataset" - for ((seq=1; seq<=$num_splits; seq++)); do - fseq=$(printf "%04d" $seq) - gunzip -c $split_dir/*.*${fseq}.jsonl.gz | \ - shuf | gzip -c > $split_dir/multidataset/multidataset_cuts_train.${fseq}.jsonl.gz - done - touch $split_dir/.multidataset_mix.done - fi - - touch data/fbank/multidataset_split/.multidataset.done - fi -fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index c44a2ad3e..0280193ca 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -47,6 +47,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. 
@@ -88,6 +90,8 @@ def fast_beam_search_one_best( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + subtract_ilme=subtract_ilme, + ilme_scale=ilme_scale, ) best_path = one_best_decoding(lattice) @@ -428,6 +432,8 @@ def fast_beam_search( max_states: int, max_contexts: int, temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -498,6 +504,17 @@ def fast_beam_search( ) logits = logits.squeeze(1).squeeze(1) log_probs = (logits / temperature).log_softmax(dim=-1) + if subtract_ilme: + ilme_logits = model.joiner( + torch.zeros_like( + current_encoder_out, device=current_encoder_out.device + ).unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + ilme_logits = ilme_logits.squeeze(1).squeeze(1) + ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) + log_probs -= ilme_scale * ilme_log_probs decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) @@ -1244,7 +1261,7 @@ def modified_beam_search_lm_rescore( # get the best hyp with different lm_scale for lm_scale in lm_scale_list: - key = f"nnlm_scale_{lm_scale}" + key = f"nnlm_scale_{lm_scale:.2f}" tot_scores = am_scores.values + lm_scores * lm_scale ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) max_indexes = ragged_tot_scores.argmax().tolist() @@ -1257,6 +1274,222 @@ def modified_beam_search_lm_rescore( return ans +def modified_beam_search_lm_rescore_LODR( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + LODR_lm: NgramLm, + sp: spm.SentencePieceProcessor, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
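+ Note: in practice the return value is a dict that maps a key of the
+ form "nnlm_scale_xx_lodr_scale_yy" to the hypotheses decoded with that
+ pair of scales. `LODR_lm` is the low-order n-gram LM (e.g. a bi-gram)
+ whose scores are subtracted during rescoring, `sp` is the BPE model
+ used to turn token IDs back into pieces for n-gram scoring, and
+ `lm_scale_list` contains the neural LM scales to try.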
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + 
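+ # Restore the original utterance order; pack_padded_sequence() sorted
+ # the utterances by length (enforce_sorted=False), and unsorted_indices
+ # undoes that permutation.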
for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py new file mode 100755 index 000000000..32eb9eda3 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -0,0 +1,681 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-25-avg-5.pt" + +cd exp +ln -s pretrained-epoch-25-avg-5.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx-streaming.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. + +You can find the exported models in +https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-en-2023-05-09 +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from onnxruntime.quantization import QuantType, quantize_dynamic +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + self.num_encoder_layers = encoder.encoder_layers + self.encoder_dim = encoder.d_model + self.cnn_module_kernel = encoder.cnn_module_kernel + + # Note you can tune these values + self.left_context = 64 # after subsampling + self.chunk_size = 16 # after subsampling + self.right_context = 0 # after subsampling + + subsampling_factor = 4 + self.pad_length = (self.right_context + 2) * subsampling_factor + 3 + + self.T = (self.chunk_size * subsampling_factor) + self.pad_length + self.decode_chunk_len = self.chunk_size * subsampling_factor + + def forward( + self, + x: torch.Tensor, + cached_attn: torch.Tensor, + cached_conv: torch.Tensor, + processed_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, self.T, C) + cached_attn: + A 3-D tensor of shape + (num_encoder_layers, self.left_context, N, self.encoder_dim) + cached_conv: + A 3-D tensor of shape + (num_encoder_layers, self.cnn_module_kernel-1, N, self.encoder_dim) + processed_lens: + A 1-D tensor of shape (N,). It contains number of processed frames + after subsampling. Its dtype is torch.int64. 
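+ (Note: although the signature above is annotated as
+ Tuple[torch.Tensor, torch.Tensor], this method actually returns the
+ three tensors described below.)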
+ Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, self.chunk_size, joiner_dim) + - new_cached_attn, it has the same shape as cached_attn + - new_cached_conv, it has the same shape as cached_conv + """ + assert x.size(1) == self.T, (x.shape, self.T) + N = x.size(0) + x_lens = torch.full((N,), fill_value=self.T, device=x.device, dtype=torch.int64) + + ( + encoder_out, + _, + [new_cached_attn, new_cached_conv], + ) = self.encoder.streaming_forward( + x, + x_lens, + states=[cached_attn, cached_conv], + processed_lens=processed_lens, + left_context=self.left_context, + right_context=self.right_context, + chunk_size=self.chunk_size, + ) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, new_cached_attn, new_cached_conv + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
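+ Note: the input/output description above does not match the streaming
+ model exported below, which takes (x, cached_attn, cached_conv,
+ processed_lens) as inputs and produces (encoder_out, new_cached_attn,
+ new_cached_conv) as outputs; see the input_names and output_names
+ passed to torch.onnx.export().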
+ """ + N = 1 + x = torch.zeros(N, encoder_model.T, 80, dtype=torch.float32) + cached_attn = torch.zeros( + encoder_model.num_encoder_layers, + encoder_model.left_context, + N, + encoder_model.encoder_dim, + ) + cached_conv = torch.zeros( + encoder_model.num_encoder_layers, + encoder_model.cnn_module_kernel - 1, + N, + encoder_model.encoder_dim, + ) + processed_lens = torch.zeros((N,), dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, cached_attn, cached_conv, processed_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "cached_attn", "cached_conv", "processed_lens"], + output_names=["encoder_out", "new_cached_attn", "new_cached_conv"], + dynamic_axes={ + "x": {0: "N"}, + "cached_attn": {2: "N"}, + "cached_conv": {2: "N"}, + "processed_lens": {0: "N"}, + "encoder_out": {0: "N"}, + "new_cached_attn": {2: "N"}, + "new_cached_conv": {2: "N"}, + }, + ) + + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless5", + "pad_length": str(encoder_model.pad_length), + "decode_chunk_len": str(encoder_model.decode_chunk_len), + "encoder_dim": str(encoder_model.encoder_dim), + "num_encoder_layers": str(encoder_model.num_encoder_layers), + "cnn_module_kernel": str(encoder_model.cnn_module_kernel), + "left_context": str(encoder_model.left_context), + "right_context": str(encoder_model.right_context), + "chunk_size": str(encoder_model.chunk_size), + "T": str(encoder_model.T), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+    The exported joiner model has two inputs:
+
+        - encoder_out: a tensor of shape (N, joiner_dim)
+        - decoder_out: a tensor of shape (N, joiner_dim)
+
+    and produces one output:
+
+        - logit: a tensor of shape (N, vocab_size)
+    """
+    joiner_dim = joiner_model.output_linear.weight.shape[1]
+    logging.info(f"joiner dim: {joiner_dim}")
+
+    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+
+    torch.onnx.export(
+        joiner_model,
+        (projected_encoder_out, projected_decoder_out),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=[
+            "encoder_out",
+            "decoder_out",
+        ],
+        output_names=["logit"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    meta_data = {
+        "joiner_dim": str(joiner_dim),
+    }
+    add_meta_data(filename=joiner_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    if not params.causal_convolution:
+        logging.info("Setting causal_convolution to True for exporting streaming models")
+        params.causal_convolution = True
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum(p.numel() for p in model.parameters())
+    logging.info(f"Number of model parameters: {num_param}")
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py 
b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py new file mode 100755 index 000000000..29be4c655 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-25-avg-5.pt" +cd exp +ln -s pretrained-epoch-25-avg-5.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx-streaming.py \ + --bpe-model ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./pruned_transducer_stateless5/onnx_pretrained-streaming.py \ + --encoder-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/exp/joiner-epoch-99-avg-1.onnx \ + --tokens=./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/data/lang_bpe_500/tokens.txt \ + ./icefall_librispeech_streaming_pruned_transducer_stateless5_20220729/test_waves/1221-135766-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. + +You can find the exported models in +https://huggingface.co/csukuangfj/sherpa-onnx-streaming-conformer-en-2023-05-09 +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. 
", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + print(encoder_meta) + + model_type = encoder_meta["model_type"] + assert model_type == "conformer", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + pad_length = int(encoder_meta["pad_length"]) + + encoder_dim = int(encoder_meta["encoder_dim"]) + cnn_module_kernel = int(encoder_meta["cnn_module_kernel"]) + left_context = int(encoder_meta["left_context"]) + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + + self.cached_attn = torch.zeros( + num_encoder_layers, + left_context, + batch_size, + encoder_dim, + ).numpy() + self.cached_conv = torch.zeros( + num_encoder_layers, + cnn_module_kernel - 1, + batch_size, + encoder_dim, + ).numpy() + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"encoder_dim: {encoder_dim}") + logging.info(f"cnn_module_kernel: {cnn_module_kernel}") + logging.info(f"left_context: {left_context}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, x: torch.Tensor, processed_lens: int + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + assert x.size(0) == 1 + encoder_input = { + "x": x.numpy(), + "cached_attn": self.cached_attn, + "cached_conv": self.cached_conv, + "processed_lens": torch.full( + (1,), fill_value=processed_lens, dtype=torch.int64 + ).numpy(), + } + encoder_output = ["encoder_out", "new_cached_attn", "new_cached_conv"] + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.cached_attn = states[0] + 
self.cached_conv = states[1] + + def run_encoder(self, x: torch.Tensor, num_processed_frames: int) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, self.T, C). It only implements N == 1 + num_processed_frames: + Number of processed frames before subsampling. + Returns: + Return a 3-D tensor of shape (N, chunk_size, joiner_dim) + """ + # assume subsampling_factor is 4 + num_processed_frames = num_processed_frames // 4 + encoder_input, encoder_output_names = self._build_encoder_input_output( + x, num_processed_frames + ) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. 
+ """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(1.0 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames, num_processed_frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py index dcb4cd141..07c7126fa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py @@ -25,29 +25,53 @@ from lhotse import CutSet, load_manifest_lazy class MultiDataset: - def __init__(self, manifest_dir: str): + def __init__(self, manifest_dir: str, cv_manifest_dir: str): """ Args: manifest_dir: It is expected to contain the following files: - - multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz + - librispeech_cuts_train-all-shuf.jsonl.gz + - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz + + cv_manifest_dir: + It is expected 
to contain the following files: + + - cv-en_cuts_train.jsonl.gz """ self.manifest_dir = Path(manifest_dir) + self.cv_manifest_dir = Path(cv_manifest_dir) def train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") - filenames = glob.glob( - f"{self.manifest_dir}/multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz" + # LibriSpeech + logging.info(f"Loading LibriSpeech in lazy mode") + librispeech_cuts = load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" ) - pattern = re.compile(r"multidataset_cuts_train.([0-9]+).jsonl.gz") + # GigaSpeech + filenames = glob.glob( + f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz" + ) + + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] - logging.info(f"Loading {len(sorted_filenames)} splits") + logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode") - return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + gigaspeech_cuts = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + + # CommonVoice + logging.info(f"Loading CommonVoice in lazy mode") + commonvoice_cuts = load_manifest_lazy( + self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz" + ) + + return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 01c9500ce..ed6dfc28f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -60,13 +60,13 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from multidataset import MultiDataset from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer +from multidataset import MultiDataset from optim import Eden, ScaledAdam from torch import Tensor from torch.cuda.amp import GradScaler @@ -1053,10 +1053,12 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.use_multidataset: - multidataset = MultiDataset(params.manifest_dir) + multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir) train_cuts = multidataset.train_cuts() else: - if params.full_libri: + if params.mini_libri: + train_cuts = librispeech.train_clean_5_cuts() + elif params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() else: train_cuts = librispeech.train_clean_100_cuts() @@ -1108,8 +1110,11 @@ def run(rank, world_size, args): train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() + if params.mini_libri: + valid_cuts = librispeech.dev_clean_2_cuts() + else: + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) if not params.use_multidataset and not params.print_diagnostics: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 8aa0d8689..3444f8193 100755 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -123,10 +123,13 @@ from beam_search import ( greedy_search_batch, modified_beam_search, modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -134,7 +137,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.lm_wrapper import LmScorer from icefall.utils import ( AttributeDict, setup_logger, @@ -336,6 +338,21 @@ def get_parser(): """, ) + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + add_model_arguments(parser) return parser @@ -349,6 +366,8 @@ def decode_one_batch( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -483,6 +502,18 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -493,6 +524,18 @@ def decode_one_batch( LM=LM, lm_scale_list=lm_scale_list, ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) else: batch_size = encoder_out.size(0) @@ -531,7 +574,10 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} - elif params.decoding_method == "modified_beam_search_lm_rescore": + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): @@ -550,6 +596,8 @@ def decode_dataset( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -568,6 +616,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + ngram_lm: + A n-gram LM to be used for LODR. 
Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -600,6 +650,8 @@ def decode_dataset( word_table=word_table, batch=batch, LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -677,8 +729,10 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", "modified_beam_search_lm_shallow_fusion", "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -822,7 +876,12 @@ def main(): model.eval() # only load the neural network LM if required - if params.use_shallow_fusion or "lm" in params.decoding_method: + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): LM = LmScorer( lm_type=params.lm_type, params=params, @@ -834,6 +893,35 @@ def main(): else: LM = None + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -866,8 +954,10 @@ def main(): test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] + import time for test_set, test_dl in zip(test_sets, test_dl): + start = time.time() results_dict = decode_dataset( dl=test_dl, params=params, @@ -876,7 +966,10 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) + logging.info(f"Elasped time for {test_set}: {time.time() - start}") save_results( params=params, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index b2f9ffc09..90428133d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1049,10 +1049,13 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + if params.mini_libri: + train_cuts = librispeech.train_clean_5_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += 
librispeech.train_other_500_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1104,8 +1107,11 @@ def run(rank, world_size, args): train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() + if params.mini_libri: + valid_cuts = librispeech.dev_clean_2_cuts() + else: + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) # if not params.print_diagnostics: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index c5787835d..c47964b07 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -86,8 +86,16 @@ class LibriSpeechAsrDataModule: "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + group.add_argument( "--manifest-dir", type=Path, @@ -393,6 +401,13 @@ class LibriSpeechAsrDataModule: ) return test_dl + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") @@ -424,6 +439,13 @@ class LibriSpeechAsrDataModule: self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" ) + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + @lru_cache() def dev_clean_cuts(self) -> CutSet: logging.info("About to get dev-clean cuts") diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py old mode 100644 new mode 100755 index 32c248d7e..c8562f4fb --- a/egs/timit/ASR/local/compile_hlg.py +++ b/egs/timit/ASR/local/compile_hlg.py @@ -100,7 +100,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py old mode 100644 new mode 100755 diff --git a/egs/timit/ASR/local/prepare_lang.py b/egs/timit/ASR/local/prepare_lang.py old mode 100644 new mode 100755 diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py old mode 100644 new mode 100755 diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh index 148a9f51b..f25fe5add 100644 --- a/egs/timit/ASR/prepare.sh +++ b/egs/timit/ASR/prepare.sh @@ -59,7 +59,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then # using: `sudo apt-get install git-lfs && git-lfs install` [ ! 
-e $dl_dir/lm ] && mkdir -p $dl_dir/lm git clone https://huggingface.co/luomingshuang/timit_lm $dl_dir/lm - cd $dl_dir/lm && git lfs pull + pushd $dl_dir/lm + git lfs pull + popd fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index 7234ca929..e0a94bf08 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -78,10 +78,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 diff --git a/icefall/__init__.py b/icefall/__init__.py index 82d21706c..5d846b41d 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -8,6 +8,12 @@ from . import ( utils ) +from .byte_utils import ( + byte_decode, + byte_encode, + smart_byte_decode, +) + from .checkpoint import ( average_checkpoints, find_checkpoints, @@ -49,6 +55,7 @@ from .utils import ( get_alignments, get_executor, get_texts, + is_cjk, is_jit_tracing, is_module_available, l1_norm, @@ -64,6 +71,7 @@ from .utils import ( store_transcripts, str2bool, subsequent_chunk_mask, + tokenize_by_CJK_char, write_error_stats, ) diff --git a/icefall/byte_utils.py b/icefall/byte_utils.py new file mode 100644 index 000000000..7ee84ad27 --- /dev/null +++ b/icefall/byte_utils.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# +# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py + +import re +import unicodedata + + +WHITESPACE_NORMALIZER = re.compile(r"\s+") +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) + +PRINTABLE_BASE_CHARS = [ + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 308, + 309, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 330, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, +] + +for c in PRINTABLE_BASE_CHARS: + assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c + +BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)} +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} + + +def byte_encode(x: str) -> str: + normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) + return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) + + +def byte_decode(x: str) -> str: + try: + return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") + except ValueError: + return "" + + +def smart_byte_decode(x: str) -> str: + output = byte_decode(x) + if output == "": + # DP the best recovery (max valid chars) if it's broken + n_bytes = len(x) + f = [0 for _ in range(n_bytes + 1)] + pt = [0 for _ in range(n_bytes + 1)] + for i in range(1, n_bytes + 1): + f[i], pt[i] = f[i - 1], i - 1 + for j in range(1, min(4, i) + 1): + if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: + f[i], pt[i] = f[i - j] + 1, i - j + cur_pt = n_bytes + while cur_pt > 0: + if f[cur_pt] == f[pt[cur_pt]] + 1: + output = byte_decode(x[pt[cur_pt] : cur_pt]) + output + cur_pt = pt[cur_pt] + return output diff --git a/icefall/rnn_lm/.gitignore b/icefall/rnn_lm/.gitignore new file mode 100644 index 000000000..877fb1e18 --- /dev/null +++ b/icefall/rnn_lm/.gitignore @@ -0,0 +1 @@ +icefall-librispeech-rnn-lm diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py new file mode 100755 index 000000000..d51a4b76b --- /dev/null +++ 
b/icefall/rnn_lm/check-onnx-streaming.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +""" +Usage: + +./check-onnx-streaming.py \ + --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ + --onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx + +Note: You can download pre-trained models from +https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + +""" + +import argparse +import logging +from typing import Tuple + +import onnxruntime as ort +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx", + required=True, + type=str, + help="Path to the onnx model", + ) + + return parser + + +class OnnxModel: + def __init__(self, filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + ) + + meta_data = self.model.get_modelmeta().custom_metadata_map + self.sos_id = int(meta_data["sos_id"]) + self.eos_id = int(meta_data["eos_id"]) + self.vocab_size = int(meta_data["vocab_size"]) + self.num_layers = int(meta_data["num_layers"]) + self.hidden_size = int(meta_data["hidden_size"]) + print(meta_data) + + def __call__( + self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + self.model.get_outputs()[2].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: y.numpy(), + self.model.get_inputs()[2].name: h0.numpy(), + self.model.get_inputs()[3].name: c0.numpy(), + }, + ) + return ( + torch.from_numpy(out[0]), + torch.from_numpy(out[1]), + torch.from_numpy(out[2]), + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit).cpu() + onnx_model = OnnxModel(args.onnx) + N = torch.arange(1, 5).tolist() + + num_layers = onnx_model.num_layers + hidden_size = onnx_model.hidden_size + + for n in N: + L = torch.randint(low=1, high=100, size=(1,)).item() + x = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + h0 = torch.rand(num_layers, n, hidden_size) + c0 = torch.rand(num_layers, n, hidden_size) + + torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0) + onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0) + + for torch_v, onnx_v in zip( + (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) + ): + + assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( + torch_v.shape, + onnx_v.shape, + (torch_v - onnx_v).abs().max(), + ) + print(n, L, torch_v.sum(), onnx_v.sum()) + + +if __name__ == "__main__": + torch.manual_seed(20230423) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/check-onnx.py b/icefall/rnn_lm/check-onnx.py new file mode 100755 index 000000000..24c5395f8 --- /dev/null +++ b/icefall/rnn_lm/check-onnx.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +""" +Usage: + +./check-onnx.py \ + --jit 
./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ + --onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx + +Note: You can download pre-trained models from +https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx", + required=True, + type=str, + help="Path to the onnx model", + ) + + return parser + + +class OnnxModel: + def __init__(self, filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + ) + + meta_data = self.model.get_modelmeta().custom_metadata_map + self.sos_id = int(meta_data["sos_id"]) + self.eos_id = int(meta_data["eos_id"]) + self.vocab_size = int(meta_data["vocab_size"]) + print(meta_data) + + def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit).cpu() + onnx_model = OnnxModel(args.onnx) + N = torch.arange(1, 5).tolist() + + for n in N: + L = torch.randint(low=1, high=100, size=(1,)).item() + x = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + x_lens = torch.full((n,), fill_value=L, dtype=torch.int64) + if n > 1: + x_lens[0] = L // 2 + 1 + + sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1) + sos_x = torch.cat([sos, x], dim=1) + + pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1) + x_eos = torch.cat([x, pad_col], dim=1) + + row_index = torch.arange(0, n, dtype=x.dtype) + x_eos[row_index, x_lens] = onnx_model.eos_id + + torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1) + onnx_nll = onnx_model(x, x_lens) + # Note: For int8 models, the differences may be quite large, + # e.g., within 0.9 + assert torch.allclose(torch_nll, onnx_nll), ( + torch_nll, + onnx_nll, + ) + print(n, L, torch_nll, onnx_nll) + + +if __name__ == "__main__": + torch.manual_seed(20230420) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py new file mode 100755 index 000000000..dfede708b --- /dev/null +++ b/icefall/rnn_lm/export-onnx.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +import argparse +import logging +from pathlib import Path + +import onnx +import torch +from model import RnnLmModel +from onnxruntime.quantization import QuantType, quantize_dynamic + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import AttributeDict, str2bool +from typing import Dict +from train import get_params + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. 
+ meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +# A wrapper for RnnLm model to simpily the C++ calling code +# when exporting the model to ONNX. +# +# TODO(fangjun): The current wrapper works only for non-streaming ASR +# since we don't expose the LM state and it is used to score +# a complete sentence at once. +class RnnLmModelWrapper(torch.nn.Module): + def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int): + super().__init__() + self.model = model + self.sos_id = sos_id + self.eos_id = eos_id + + def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (N, L) with dtype torch.int64. + It does not contain SOS or EOS. We will add SOS and EOS inside + this function. + x_lens: + A 1-D tensor of shape (N,) with dtype torch.int64. It contains + number of valid tokens in ``x`` before padding. + Returns: + Return a 1-D tensor of shape (N,) containing negative loglikelihood. + Its dtype is torch.float32 + """ + N = x.size(0) + + sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand( + N, 1 + ) + sos_x = torch.cat([sos_tensor, x], dim=1) + + pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1) + x_eos = torch.cat([x, pad_col], dim=1) + + row_index = torch.arange(0, N, dtype=x.dtype) + x_eos[row_index, x_lens] = self.eos_id + + # use x_lens + 1 here since we prepended x with sos + return ( + self.model(x=sos_x, y=x_eos, lengths=x_lens + 1) + .to(torch.float32) + .sum(dim=1) + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + return parser + + +def export_without_state( + model: RnnLmModel, + filename: str, + params: AttributeDict, + opset_version: int, +): + model_wrapper = RnnLmModelWrapper( + model, + sos_id=params.sos_id, + eos_id=params.eos_id, + ) + + N = 1 + L = 20 + x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + x_lens = torch.full((N,), fill_value=L, dtype=torch.int64) + + # Note(fangjun): The following warnings can be ignored. + # We can use ./check-onnx.py to validate the exported model with batch_size > 1 + """ + torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX + with a batch_size other than 1, with a variable length with LSTM can cause + an error when running the ONNX model with a different batch size. Make sure + to save the model with a batch size of 1, or define the initial states + (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX + with a batch_size other than 1, " + + """ + + torch.onnx.export( + model_wrapper, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["nll"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_lens": {0: "N"}, + "nll": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "rnnlm", + "version": "1", + "model_author": "k2-fsa", + "comment": "rnnlm without state", + "sos_id": str(params.sos_id), + "eos_id": str(params.eos_id), + "vocab_size": str(params.vocab_size), + "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +def export_with_state( + model: RnnLmModel, + filename: str, + params: AttributeDict, + opset_version: int, +): + N = 1 + L = 20 + num_layers = model.rnn.num_layers + hidden_size = model.rnn.hidden_size + embedding_dim = model.embedding_dim + + x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + h0 = torch.zeros(num_layers, N, hidden_size) + c0 = torch.zeros(num_layers, N, hidden_size) + + # Note(fangjun): The following warnings can be ignored. + # We can use ./check-onnx.py to validate the exported model with batch_size > 1 + """ + torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX + with a batch_size other than 1, with a variable length with LSTM can cause + an error when running the ONNX model with a different batch size. Make sure + to save the model with a batch size of 1, or define the initial states + (h0/c0) as inputs of the model. 
warnings.warn("Exporting a model to ONNX + with a batch_size other than 1, " + + """ + + torch.onnx.export( + model, + (x, h0, c0), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "h0", "c0"], + output_names=["log_softmax", "next_h0", "next_c0"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "h0": {1: "N"}, + "c0": {1: "N"}, + "log_softmax": {0: "N"}, + "next_h0": {1: "N"}, + "next_c0": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "rnnlm", + "version": "1", + "model_author": "k2-fsa", + "comment": "rnnlm state", + "sos_id": str(params.sos_id), + "eos_id": str(params.eos_id), + "vocab_size": str(params.vocab_size), + "num_layers": str(num_layers), + "hidden_size": str(hidden_size), + "embedding_dim": str(embedding_dim), + "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + device = torch.device("cpu") + logging.info(f"device: {device}") + + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting model without state") + filename = params.exp_dir / f"no-state-{suffix}.onnx" + export_without_state( + model=model, + filename=filename, + params=params, + opset_version=opset_version, + ) + + filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QInt8, + ) + + # now for streaming export + saved_forward = model.__class__.forward + model.__class__.forward = model.__class__.score_token_onnx + streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx" + export_with_state( + model=model, + filename=streaming_filename, + params=params, + opset_version=opset_version, + ) + model.__class__.forward = saved_forward + + streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx" + quantize_dynamic( + model_input=streaming_filename, + model_output=streaming_filename_int8, + weight_type=QuantType.QInt8, + ) + + +if 
__name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/export-onnx.sh b/icefall/rnn_lm/export-onnx.sh new file mode 100755 index 000000000..6e3262b5e --- /dev/null +++ b/icefall/rnn_lm/export-onnx.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# We use the model from +# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main +# as an example + +export CUDA_VISIBLE_DEVICES= + +if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + pushd icefall-librispeech-rnn-lm/exp + git lfs pull --include "pretrained.pt" + ln -s pretrained.pt epoch-99.pt + popd +fi + +python3 ./export-onnx.py \ + --exp-dir ./icefall-librispeech-rnn-lm/exp \ + --epoch 99 \ + --avg 1 \ + --vocab-size 500 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --tie-weights 1 + diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index a8598a1ce..dadf23009 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -26,7 +26,7 @@ import torch from model import RnnLmModel from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, load_averaged_model, str2bool +from icefall.utils import AttributeDict, str2bool def get_parser(): @@ -118,6 +118,7 @@ def get_parser(): return parser +@torch.no_grad() def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) @@ -180,6 +181,10 @@ def main(): if params.jit: logging.info("Using torch.jit.script") + + model.__class__.score_token_onnx = torch.jit.export( + model.__class__.score_token_onnx + ) model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" model.save(str(filename)) diff --git a/icefall/rnn_lm/export.sh b/icefall/rnn_lm/export.sh new file mode 100755 index 000000000..678bc294e --- /dev/null +++ b/icefall/rnn_lm/export.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# We use the model from +# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main +# as an example + +export CUDA_VISIBLE_DEVICES= + +if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + pushd icefall-librispeech-rnn-lm/exp + git lfs pull --include "pretrained.pt" + ln -s pretrained.pt epoch-99.pt + popd +fi + +python3 ./export.py \ + --exp-dir ./icefall-librispeech-rnn-lm/exp \ + --epoch 99 \ + --avg 1 \ + --vocab-size 500 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --tie-weights 1 \ + --jit 1 + diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index ebb3128e3..5eacf5d40 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -15,6 +15,7 @@ # limitations under the License. 
import logging +from typing import Tuple import torch import torch.nn.functional as F @@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module): and https://arxiv.org/abs/1611.01462 """ super().__init__() + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.tie_weights = tie_weights self.input_embedding = torch.nn.Embedding( num_embeddings=vocab_size, @@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module): self.cache = {} + def streaming_forward( + self, + x: torch.Tensor, + y: torch.Tensor, + h0: torch.Tensor, + c0: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 2-D tensor of shape (N, L). We won't prepend it with SOS. + y: + A 2-D tensor of shape (N, L). We won't append it with EOS. + h0: + A 3-D tensor of shape (num_layers, N, hidden_size). + (If proj_size > 0, then it is (num_layers, N, proj_size)) + c0: + A 3-D tensor of shape (num_layers, N, hidden_size). + Returns: + Return a tuple containing 3 tensors: + - negative log-likelihood (nll), a 1-D tensor of shape (N,) + - next_h0, a 3-D tensor with the same shape as h0 + - next_c0, a 3-D tensor with the same shape as c0 + """ + assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) + assert x.shape == y.shape, (x.shape, y.shape) + + # embedding is of shape (N, L, embedding_dim) + embedding = self.input_embedding(x) + # Note: We use batch_first==True + rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0)) + logits = self.output_linear(rnn_out) + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + + batch_size = x.size(0) + nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1) + return nll_loss, next_h0, next_c0 + def forward( self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor ) -> torch.Tensor: @@ -188,6 +234,33 @@ class RnnLmModel(torch.nn.Module): return logits[:, 0].log_softmax(-1), states + def score_token_onnx( + self, + x: torch.Tensor, + state_h: torch.Tensor, + state_c: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Score a batch of tokens, i.e., each sample in the batch should be a + single token. For example, x = torch.tensor([[5],[10],[20]]) + + + Args: + x (torch.Tensor): + A batch of tokens, one token per sample, of shape (N, 1) + state_h: + The LSTM hidden state h, of shape (num_layers, bs, hidden_dim) + state_c: + The LSTM cell state c, of shape (num_layers, bs, hidden_dim) + + Returns: + A tuple of (log_softmax, next_h0, next_c0). + """ + embedding = self.input_embedding(x) + rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c)) + logits = self.output_linear(rnn_out) + + return logits[:, 0].log_softmax(-1), next_h0, next_c0 + def forward_with_state( self, tokens, token_lens, sos_id, eos_id, blank_id, state=None ): diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 91df4f921..0f0887859 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -22,6 +22,7 @@ Usage: --world-size 2 \ --num-epochs 1 \ --use-fp16 0 \ + --tie-weights 0 \ --embedding-dim 800 \ --hidden-dim 200 \ --num-layers 2 \ diff --git a/icefall/utils.py b/icefall/utils.py index 1fd9156bd..4aa8197ad 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1306,6 +1306,31 @@ def tokenize_by_bpe_model( return txt_with_bpe +def tokenize_by_CJK_char(line: str) -> str: + """ + Tokenize a line of text by CJK characters. + + Note: All returned characters will be upper case.
+ + Example: + input = "你好世界是 hello world 的中文" + output = "你 好 世 界 是 HELLO WORLD 的 中 文" + + Args: + line: + The input text. + + Return: + A new string tokenized by CJK characters. + """ + # The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py + pattern = re.compile( + r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" + ) + chars = pattern.split(line.strip().upper()) + return " ".join([w.strip() for w in chars if w.strip()]) + + def display_and_save_batch( batch: dict, params: AttributeDict, @@ -1764,3 +1789,34 @@ def parse_fsa_timestamps_and_texts( utt_time_pairs.append(list(zip(start, end))) return utt_time_pairs, utt_words + + +# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py +def is_cjk(character): + """ + Python port of Moses' code to check for CJK character. + + >>> is_cjk(u'\u33fe') + True + >>> is_cjk(u'\uFE5F') + False + + :param character: The character that needs to be checked. + :type character: char + :return: bool + """ + return any( + [ + start <= ord(character) <= end + for start, end in [ + (4352, 4607), + (11904, 42191), + (43072, 43135), + (44032, 55215), + (63744, 64255), + (65072, 65103), + (65381, 65500), + (131072, 196607), + ] + ] + )
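
The with-state ONNX model exported by `export-onnx.py` can be driven token by token with onnxruntime. The following is a minimal sketch, not part of the patch; the file name and the use of `sos_id` as the first token are assumptions based on the defaults in `export-onnx.sh` and on the metadata written by `export_with_state`:

```python
import numpy as np
import onnxruntime as ort

# File name assumes export-onnx.sh was run as-is (--epoch 99 --avg 1).
filename = "./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx"
session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])

# The exporter stores the model configuration as string-valued metadata.
meta = session.get_modelmeta().custom_metadata_map
num_layers = int(meta["num_layers"])
hidden_size = int(meta["hidden_size"])
sos_id = int(meta["sos_id"])

N = 1  # batch size
h0 = np.zeros((num_layers, N, hidden_size), dtype=np.float32)
c0 = np.zeros((num_layers, N, hidden_size), dtype=np.float32)

# Feed one token per step; the model returns log-probs for the next token
# together with the updated LSTM states, which are fed back on the next step.
x = np.array([[sos_id]], dtype=np.int64)
log_probs, h0, c0 = session.run(
    ["log_softmax", "next_h0", "next_c0"],
    {"x": x, "h0": h0, "c0": c0},
)
print(log_probs.shape)  # (N, vocab_size)
```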
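
The change to `export.py` marks `score_token_onnx` with `torch.jit.export`, so the method is preserved by `torch.jit.script` and can be called on the saved `cpu_jit.pt`. A rough usage sketch, assuming the path and flags used in `export.sh`:

```python
import torch

# Path and sizes assume export.sh was run as-is
# (--epoch 99 --avg 1 --hidden-dim 2048 --num-layers 3 --jit 1).
model = torch.jit.load("./icefall-librispeech-rnn-lm/exp/cpu_jit.pt")
model.eval()

num_layers, hidden_dim, batch_size = 3, 2048, 2
h = torch.zeros(num_layers, batch_size, hidden_dim)
c = torch.zeros(num_layers, batch_size, hidden_dim)

# One token per sample, as described in the score_token_onnx docstring.
tokens = torch.tensor([[5], [10]], dtype=torch.int64)
with torch.no_grad():
    log_probs, h, c = model.score_token_onnx(tokens, h, c)
print(log_probs.shape)  # (batch_size, vocab_size)
```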
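
Unlike `forward`, the new `streaming_forward` neither prepends SOS nor appends EOS, and it accepts and returns the LSTM state, so a long token sequence can be scored chunk by chunk. A toy sketch, assuming the constructor arguments used in `export.py` and that it is run from `icefall/rnn_lm`:

```python
import torch

from model import RnnLmModel  # icefall/rnn_lm/model.py

# Toy sizes only; real models use the values passed to train.py.
lm = RnnLmModel(
    vocab_size=100, embedding_dim=64, hidden_dim=64, num_layers=2, tie_weights=False
)
lm.eval()

N, L, chunk = 2, 8, 4
x = torch.randint(0, 100, (N, L))  # input tokens
y = torch.randint(0, 100, (N, L))  # target (next-token) labels

h = torch.zeros(2, N, 64)
c = torch.zeros(2, N, 64)

nll = torch.zeros(N)
with torch.no_grad():
    for s in range(0, L, chunk):
        # The states returned for one chunk are fed into the next chunk.
        chunk_nll, h, c = lm.streaming_forward(
            x[:, s : s + chunk], y[:, s : s + chunk], h, c
        )
        nll += chunk_nll
print(nll)  # per-utterance negative log-likelihood over all L positions
```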
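
The two text utilities added to `icefall/utils.py` can be sanity-checked from a Python prompt once the patch is applied; the expected output mirrors the docstring example:

```python
from icefall.utils import is_cjk, tokenize_by_CJK_char

print(tokenize_by_CJK_char("你好世界是 hello world 的中文"))
# 你 好 世 界 是 HELLO WORLD 的 中 文

print(is_cjk("好"), is_cjk("a"))
# True False
```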