From 11523c5b894f42ded965dcb974fef9a8a8122518 Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Thu, 6 Jul 2023 19:11:01 +0800
Subject: [PATCH 001/113] Shallow fusion & LODR documentation (#1142)

* add shallow fusion documentation

* add documentation for LODR

* upload docs for LM rescoring
---
 docs/source/conf.py                           |   1 +
 .../decoding-with-langugage-models/LODR.rst   | 184 +++++++++++++
 .../decoding-with-langugage-models/index.rst  |  12 +
 .../rescoring.rst                             | 252 ++++++++++++++++++
 .../shallow-fusion.rst                        | 176 ++++++++++++
 docs/source/index.rst                         |   5 +
 .../librispeech/distillation.rst              |   8 +-
 .../pruned_transducer_stateless.rst           |  18 +-
 .../recipes/Streaming-ASR/introduction.rst    |   4 +-
 .../pruned_transducer_stateless.rst           |  10 +-
 .../librispeech/zipformer_transducer.rst      |   4 +-
 11 files changed, 652 insertions(+), 22 deletions(-)
 create mode 100644 docs/source/decoding-with-langugage-models/LODR.rst
 create mode 100644 docs/source/decoding-with-langugage-models/index.rst
 create mode 100644 docs/source/decoding-with-langugage-models/rescoring.rst
 create mode 100644 docs/source/decoding-with-langugage-models/shallow-fusion.rst

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6901dec02..0ff3f801c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -86,6 +86,7 @@ rst_epilog = """
 .. _git-lfs: https://git-lfs.com/
 .. _ncnn: https://github.com/tencent/ncnn
 .. _LibriSpeech: https://www.openslr.org/12
+.. _Gigaspeech: https://github.com/SpeechColab/GigaSpeech
 .. _musan: http://www.openslr.org/17/
 .. _ONNX: https://github.com/onnx/onnx
 .. _onnxruntime: https://github.com/microsoft/onnxruntime
diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst
new file mode 100644
index 000000000..7ffa0c128
--- /dev/null
+++ b/docs/source/decoding-with-langugage-models/LODR.rst
@@ -0,0 +1,184 @@
+.. _LODR:
+
+LODR for RNN Transducer
+=======================
+
+
+As a type of E2E model, neural transducers are usually considered to have an internal
+language model, which learns language-level information from the training corpus.
+In real-life scenarios, there is often a mismatch between the training corpus and the target domain.
+This mismatch can be a problem when decoding neural transducer models with an external language model, because the internal
+language model can act "against" the external LM. In this tutorial, we show how to use
+`Low-order Density Ratio `_ to alleviate this effect and further improve the performance
+of language model integration.
+
+.. note::

+
+    This tutorial is based on the recipe
+    `pruned_transducer_stateless7_streaming `_,
+    which is a streaming transducer model trained on `LibriSpeech`_.
+    However, you can easily apply LODR to other recipes.
+    If you encounter any problems, please open an issue in `icefall `__.
+
+
+.. note::
+
+    For simplicity, the training and testing corpora in this tutorial are the same (`LibriSpeech`_). However,
+    you can change the testing set to any other domain (e.g. `GigaSpeech`_) and prepare the language models
+    using that corpus.
+
+First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) was first proposed `here `_
+to address the language information mismatch between the training
+corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
+are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:
+
+.. math::
+
+    \text{score}\left(y_u|\mathit{x},y\right) =
+    \log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
+    \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
+    \lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)
+
+
+where :math:`\lambda_1` and :math:`\lambda_2` are the weights of the LM scores for the target domain and the source domain, respectively.
+Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
+shallow fusion is the subtraction of the source domain LM.
+
+Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, this internal LM is
+considered to be weak and can only capture low-level language information. Therefore, `LODR `__ proposed to use
+a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
+during decoding for transducer models:
+
+.. math::
+
+    \text{score}\left(y_u|\mathit{x},y\right) =
+    \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
+    \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
+    \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)
+
+In LODR, an additional bi-gram LM estimated on the source domain (e.g. the training corpus) is required. Compared to DR,
+the only difference lies in the choice of the source domain LM. According to the original `paper `_,
+LODR achieves performance similar to DR in both intra-domain and cross-domain settings.
+As a bi-gram is much faster to evaluate, LODR is usually much faster.
+
+Now, we will show you how to use LODR in ``icefall``.
+For illustration purposes, we will use a pre-trained ASR model from this `link `_.
+If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
+The testing scenario here is intra-domain (we decode the model trained on `LibriSpeech`_ on `LibriSpeech`_ testing sets).
+
+As the initial step, let's download the pre-trained model.
+
+.. code-block:: bash
+
+    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+    $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
+    $ git lfs pull --include "pretrained.pt"
+    $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
+
+To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
+
+.. code-block:: bash
+
+    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
+    $ ./pruned_transducer_stateless7_streaming/decode.py \
+        --epoch 99 \
+        --avg 1 \
+        --use-averaged-model False \
+        --exp-dir $exp_dir \
+        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
+        --max-duration 600 \
+        --decode-chunk-len 32 \
+        --decoding-method modified_beam_search
+
+The following WERs are achieved on test-clean and test-other:
+
+.. code-block:: text
+
+    $ For test-clean, WER of different settings are:
+    $ beam_size_4 3.11 best for test-clean
+    $ For test-other, WER of different settings are:
+    $ beam_size_4 7.93 best for test-other
+
+Then, we download the external language model and bi-gram LM that are necessary for LODR.
+Note that the bi-gram is estimated on the LibriSpeech 960 hours' text.
+
+.. 
code-block:: bash + + $ # download the external LM + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + $ # create a symbolic link so that the checkpoint can be loaded + $ pushd icefall-librispeech-rnn-lm/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt + $ popd + $ + $ # download the bi-gram + $ git lfs install + $ git clone https://huggingface.co/marcoyang/librispeech_bigram + $ pushd data/lang_bpe_500 + $ ln -s ../../librispeech_bigram/2gram.fst.txt . + $ popd + +Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``: + +.. code-block:: bash + + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ lm_dir=./icefall-librispeech-rnn-lm/exp + $ lm_scale=0.42 + $ LODR_scale=-0.24 + $ ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --beam-size 4 \ + --exp-dir $exp_dir \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_LODR \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_dir \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 \ + --tokens-ngram 2 \ + --ngram-lm-scale $LODR_scale + +There are two extra arguments that need to be given when doing LODR. ``--tokens-ngram`` specifies the order of n-gram. As we +are using a bi-gram, we set it to 2. ``--ngram-lm-scale`` is the scale of the bi-gram, it should be a negative number +as we are subtracting the bi-gram's score during decoding. + +The decoding results obtained with the above command are shown below: + +.. code-block:: text + + $ For test-clean, WER of different settings are: + $ beam_size_4 2.61 best for test-clean + $ For test-other, WER of different settings are: + $ beam_size_4 6.74 best for test-other + +Recall that the lowest WER we obtained in :ref:`shallow_fusion` with beam size of 4 is ``2.77/7.08``, LODR +indeed **further improves** the WER. We can do even better if we increase ``--beam-size``: + +.. list-table:: WER of LODR with different beam sizes + :widths: 25 25 50 + :header-rows: 1 + + * - Beam size + - test-clean + - test-other + * - 4 + - 2.61 + - 6.74 + * - 8 + - 2.45 + - 6.38 + * - 12 + - 2.4 + - 6.23 \ No newline at end of file diff --git a/docs/source/decoding-with-langugage-models/index.rst b/docs/source/decoding-with-langugage-models/index.rst new file mode 100644 index 000000000..577ebbdfb --- /dev/null +++ b/docs/source/decoding-with-langugage-models/index.rst @@ -0,0 +1,12 @@ +Decoding with language models +============================= + +This section describes how to use external langugage models +during decoding to improve the WER of transducer models. + +.. toctree:: + :maxdepth: 2 + + shallow-fusion + LODR + rescoring diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst new file mode 100644 index 000000000..d71acc1e5 --- /dev/null +++ b/docs/source/decoding-with-langugage-models/rescoring.rst @@ -0,0 +1,252 @@ +.. _rescoring: + +LM rescoring for Transducer +================================= + +LM rescoring is a commonly used approach to incorporate external LM information. 
Unlike shallow-fusion-based +methods (see :ref:`shallow-fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search. +Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM. +In this tutorial, we will show you how to use external LM to rescore the n-best hypotheses decoded from neural transducer models in +`icefall `__. + +.. note:: + + This tutorial is based on the recipe + `pruned_transducer_stateless7_streaming `_, + which is a streaming transducer model trained on `LibriSpeech`_. + However, you can easily apply shallow fusion to other recipes. + If you encounter any problems, please open an issue `here `_. + +.. note:: + + For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set + to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain. + +.. HINT:: + + We recommend you to use a GPU for decoding. + +For illustration purpose, we will use a pre-trained ASR model from this `link `__. +If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`. + +As the initial step, let's download the pre-trained model. + +.. code-block:: bash + + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded + +As usual, we first test the model's performance without external LM. This can be done via the following command: + +.. code-block:: bash + + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/ + $ ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --exp-dir $exp_dir \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search + +The following WERs are achieved on test-clean and test-other: + +.. code-block:: text + + $ For test-clean, WER of different settings are: + $ beam_size_4 3.11 best for test-clean + $ For test-other, WER of different settings are: + $ beam_size_4 7.93 best for test-other + +Now, we will try to improve the above WER numbers via external LM rescoring. We will download +a pre-trained LM from this `link `__. + +.. note:: + + This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus. + You may also train a RNN LM from scratch. Please refer to this `script `__ + for training a RNN LM and this `script `__ to train a transformer LM. + +.. code-block:: bash + + $ # download the external LM + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + $ # create a symbolic link so that the checkpoint can be loaded + $ pushd icefall-librispeech-rnn-lm/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt + $ popd + + +With the RNNLM available, we can rescore the n-best hypotheses generated from `modified_beam_search`. Here, +`n` should be the number of beams, i.e ``--beam-size``. The command for LM rescoring is +as follows. 
Note that the ``--decoding-method`` is set to `modified_beam_search_lm_rescore` and ``--use-shallow-fusion`` +is set to `False`. + +.. code-block:: bash + + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ lm_dir=./icefall-librispeech-rnn-lm/exp + $ lm_scale=0.43 + $ ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --beam-size 4 \ + --exp-dir $exp_dir \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir $lm_dir \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 + +.. code-block:: text + + $ For test-clean, WER of different settings are: + $ beam_size_4 2.93 best for test-clean + $ For test-other, WER of different settings are: + $ beam_size_4 7.6 best for test-other + +Great! We made some improvements! Increasing the size of the n-best hypotheses will further boost the performance, +see the following table: + +.. list-table:: WERs of LM rescoring with different beam sizes + :widths: 25 25 25 + :header-rows: 1 + + * - Beam size + - test-clean + - test-other + * - 4 + - 2.93 + - 7.6 + * - 8 + - 2.67 + - 7.11 + * - 12 + - 2.59 + - 6.86 + +In fact, we can also apply LODR (see :ref:`LODR`) when doing LM rescoring. To do so, we need to +download the bi-gram required by LODR: + +.. code-block:: bash + + $ # download the bi-gram + $ git lfs install + $ git clone https://huggingface.co/marcoyang/librispeech_bigram + $ pushd data/lang_bpe_500 + $ ln -s ../../librispeech_bigram/2gram.arpa . + $ popd + +Then we can performn LM rescoring + LODR by changing the decoding method to `modified_beam_search_lm_rescore_LODR`. + +.. note:: + + This decoding method requires the dependency of `kenlm `_. You can install it + via this command: `pip install https://github.com/kpu/kenlm/archive/master.zip`. + +.. code-block:: bash + + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ lm_dir=./icefall-librispeech-rnn-lm/exp + $ lm_scale=0.43 + $ ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --beam-size 4 \ + --exp-dir $exp_dir \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore_LODR \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir $lm_dir \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 + +You should see the following WERs after executing the commands above: + +.. code-block:: text + + $ For test-clean, WER of different settings are: + $ beam_size_4 2.9 best for test-clean + $ For test-other, WER of different settings are: + $ beam_size_4 7.57 best for test-other + +It's slightly better than LM rescoring. If we further increase the beam size, we will see +further improvements from LM rescoring + LODR: + +.. 
list-table:: WERs of LM rescoring + LODR with different beam sizes + :widths: 25 25 25 + :header-rows: 1 + + * - Beam size + - test-clean + - test-other + * - 4 + - 2.9 + - 7.57 + * - 8 + - 2.63 + - 7.04 + * - 12 + - 2.52 + - 6.73 + +As mentioned earlier, LM rescoring is usually faster than shallow-fusion based methods. +Here, we benchmark the WERs and decoding speed of them: + +.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean) + :widths: 25 25 25 25 + :header-rows: 1 + + * - Decoding method + - beam=4 + - beam=8 + - beam=12 + * - `modified_beam_search` + - 3.11/7.93; 132s + - 3.1/7.95; 177s + - 3.1/7.96; 210s + * - `modified_beam_search_lm_shallow_fusion` + - 2.77/7.08; 262s + - 2.62/6.65; 352s + - 2.58/6.65; 488s + * - LODR + - 2.61/6.74; 400s + - 2.45/6.38; 610s + - 2.4/6.23; 870s + * - `modified_beam_search_lm_rescore` + - 2.93/7.6; 156s + - 2.67/7.11; 203s + - 2.59/6.86; 255s + * - `modified_beam_search_lm_rescore_LODR` + - 2.9/7.57; 160s + - 2.63/7.04; 203s + - 2.52/6.73; 263s + +.. note:: + + Decoding is performed with a single 32G V100, we set ``--max-duration`` to 600. + Decoding time here is only for reference and it may vary. \ No newline at end of file diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst new file mode 100644 index 000000000..0d2837372 --- /dev/null +++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst @@ -0,0 +1,176 @@ +.. _shallow_fusion: + +Shallow fusion for Transducer +================================= + +External language models (LM) are commonly used to improve WERs for E2E ASR models. +This tutorial shows you how to perform ``shallow fusion`` with an external LM +to improve the word-error-rate of a transducer model. + +.. note:: + + This tutorial is based on the recipe + `pruned_transducer_stateless7_streaming `_, + which is a streaming transducer model trained on `LibriSpeech`_. + However, you can easily apply shallow fusion to other recipes. + If you encounter any problems, please open an issue here `icefall `_. + +.. note:: + + For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set + to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain. + +.. HINT:: + + We recommend you to use a GPU for decoding. + +For illustration purpose, we will use a pre-trained ASR model from this `link `__. +If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`. + +As the initial step, let's download the pre-trained model. + +.. code-block:: bash + + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded + +To test the model, let's have a look at the decoding results without using LM. This can be done via the following command: + +.. 
code-block:: bash + + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/ + $ ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --exp-dir $exp_dir \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search + +The following WERs are achieved on test-clean and test-other: + +.. code-block:: text + + $ For test-clean, WER of different settings are: + $ beam_size_4 3.11 best for test-clean + $ For test-other, WER of different settings are: + $ beam_size_4 7.93 best for test-other + +These are already good numbers! But we can further improve it by using shallow fusion with external LM. +Training a language model usually takes a long time, we can download a pre-trained LM from this `link `__. + +.. code-block:: bash + + $ # download the external LM + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + $ # create a symbolic link so that the checkpoint can be loaded + $ pushd icefall-librispeech-rnn-lm/exp + $ git lfs pull --include "pretrained.pt" + $ ln -s pretrained.pt epoch-99.pt + $ popd + +.. note:: + + This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus. + You may also train a RNN LM from scratch. Please refer to this `script `__ + for training a RNN LM and this `script `__ to train a transformer LM. + +To use shallow fusion for decoding, we can execute the following command: + +.. code-block:: bash + + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ lm_dir=./icefall-librispeech-rnn-lm/exp + $ lm_scale=0.29 + $ ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --beam-size 4 \ + --exp-dir $exp_dir \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_dir \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 + +Note that we set ``--decoding-method modified_beam_search_lm_shallow_fusion`` and ``--use-shallow-fusion True`` +to use shallow fusion. ``--lm-type`` specifies the type of neural LM we are going to use, you can either choose +between ``rnn`` or ``transformer``. The following three arguments are associated with the rnn: + +- ``--rnn-lm-embedding-dim`` + The embedding dimension of the RNN LM + +- ``--rnn-lm-hidden-dim`` + The hidden dimension of the RNN LM + +- ``--rnn-lm-num-layers`` + The number of RNN layers in the RNN LM. + + +The decoding result obtained with the above command are shown below. + +.. code-block:: text + + $ For test-clean, WER of different settings are: + $ beam_size_4 2.77 best for test-clean + $ For test-other, WER of different settings are: + $ beam_size_4 7.08 best for test-other + +The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%. +A few parameters can be tuned to further boost the performance of shallow fusion: + +- ``--lm-scale`` + + Controls the scale of the LM. 
If too small, the external language model may not be fully utilized; if too large, + the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3. + +- ``--beam-size`` + + The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy. + +Here, we also show how `--beam-size` effect the WER and decoding time: + +.. list-table:: WERs and decoding time (on test-clean) of shallow fusion with different beam sizes + :widths: 25 25 25 25 + :header-rows: 1 + + * - Beam size + - test-clean + - test-other + - Decoding time on test-clean (s) + * - 4 + - 2.77 + - 7.08 + - 262 + * - 8 + - 2.62 + - 6.65 + - 352 + * - 12 + - 2.58 + - 6.65 + - 488 + +As we see, a larger beam size during shallow fusion improves the WER, but is also slower. + + + + + + + + diff --git a/docs/source/index.rst b/docs/source/index.rst index 8d76eb68b..a7d365a15 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,3 +34,8 @@ speech recognition recipes using `k2 `_. contributing/index huggingface/index + +.. toctree:: + :maxdepth: 2 + + decoding-with-langugage-models/index \ No newline at end of file diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst index ea9f350cd..2e8d0893a 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -1,7 +1,7 @@ Distillation with HuBERT ======================== -This tutorial shows you how to perform knowledge distillation in `icefall`_ +This tutorial shows you how to perform knowledge distillation in `icefall `_ with the `LibriSpeech`_ dataset. The distillation method used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_ @@ -13,7 +13,7 @@ for more details about MVQ-KD. `pruned_transducer_stateless4 `_. Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you - encounter any problems, please open an issue here `icefall `_. + encounter any problems, please open an issue here `icefall `__. .. note:: @@ -217,7 +217,7 @@ the following command. --exp-dir $exp_dir \ --enable-distillation True -You should get similar results as `here `_. +You should get similar results as `here `__. That's all! Feel free to experiment with your own setups and report your results. -If you encounter any problems during training, please open up an issue `here `_. +If you encounter any problems during training, please open up an issue `here `__. diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index 42fd3df77..1bc1dd984 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -8,10 +8,10 @@ with the `LibriSpeech `_ dataset. .. 
Note:: - The tutorial is suitable for `pruned_transducer_stateless `_, - `pruned_transducer_stateless2 `_, - `pruned_transducer_stateless4 `_, - `pruned_transducer_stateless5 `_, + The tutorial is suitable for `pruned_transducer_stateless `__, + `pruned_transducer_stateless2 `__, + `pruned_transducer_stateless4 `__, + `pruned_transducer_stateless5 `__, We will take pruned_transducer_stateless4 as an example in this tutorial. .. HINT:: @@ -237,7 +237,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly. .. NOTE:: - The options for `pruned_transducer_stateless5 `_ are a little different from + The options for `pruned_transducer_stateless5 `__ are a little different from other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. @@ -529,13 +529,13 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following links: - - `pruned_transducer_stateless `_ + - `pruned_transducer_stateless `__ - - `pruned_transducer_stateless2 `_ + - `pruned_transducer_stateless2 `__ - - `pruned_transducer_stateless4 `_ + - `pruned_transducer_stateless4 `__ - - `pruned_transducer_stateless5 `_ + - `pruned_transducer_stateless5 `__ See ``_ for the details of the above pretrained models diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst index e1382e77d..ac77a51d1 100644 --- a/docs/source/recipes/Streaming-ASR/introduction.rst +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -45,9 +45,9 @@ the input features. We have three variants of Emformer models in ``icefall``. - - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe `_. + - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe `__. - ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio, ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model. - See `LibriSpeech recipe `_. + See `LibriSpeech recipe `__. - ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that it uses a simplified memory bank. See `LibriSpeech recipe `_. diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst index de7102ba8..2ca70bcf3 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -6,10 +6,10 @@ with the `LibriSpeech `_ dataset. .. Note:: - The tutorial is suitable for `pruned_transducer_stateless `_, - `pruned_transducer_stateless2 `_, - `pruned_transducer_stateless4 `_, - `pruned_transducer_stateless5 `_, + The tutorial is suitable for `pruned_transducer_stateless `__, + `pruned_transducer_stateless2 `__, + `pruned_transducer_stateless4 `__, + `pruned_transducer_stateless5 `__, We will take pruned_transducer_stateless4 as an example in this tutorial. .. HINT:: @@ -264,7 +264,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly. .. 
NOTE:: - The options for `pruned_transducer_stateless5 `_ are a little different from + The options for `pruned_transducer_stateless5 `__ are a little different from other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. diff --git a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst index f0e8961d7..8b75473c6 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst @@ -6,7 +6,7 @@ with the `LibriSpeech `_ dataset. .. Note:: - The tutorial is suitable for `pruned_transducer_stateless7_streaming `_, + The tutorial is suitable for `pruned_transducer_stateless7_streaming `__, .. HINT:: @@ -642,7 +642,7 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following links: - - `pruned_transducer_stateless7_streaming `_ + - `pruned_transducer_stateless7_streaming `__ See ``_ for the details of the above pretrained models From ffe816e2a8314318a4ef6d5eaba34b62b842ba3f Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 6 Jul 2023 23:12:41 +0800 Subject: [PATCH 002/113] Fix blank skip ci test (#1167) * Fix for ci * Fix frame_reducer --- ...ned-transducer-stateless7-ctc-bs-2023-01-29.sh} | 2 +- ...n-librispeech-2023-01-29-stateless7-ctc-bs.yml} | 8 ++++---- .../frame_reducer.py | 14 +++++++------- 3 files changed, 12 insertions(+), 12 deletions(-) rename .github/scripts/{run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh => run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh} (100%) rename .github/workflows/{run-librispeech-2022-12-15-stateless7-ctc-bs.yml => run-librispeech-2023-01-29-stateless7-ctc-bs.yml} (97%) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh similarity index 100% rename from .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh rename to .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh index 761eb72e2..7d2853c17 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh @@ -21,9 +21,9 @@ tree $repo/ ls -lh $repo/test_wavs/*.wav pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/HLG.pt" git lfs pull --include "data/lang_bpe_500/L.pt" git lfs pull --include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/HLG.pt" git lfs pull --include "data/lang_bpe_500/Linv.pt" git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/cpu_jit.pt" diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml similarity index 97% rename from .github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml rename to .github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml index 40a742988..821abc25d 100644 --- a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml +++ 
b/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: run-librispeech-2022-12-15-stateless7-ctc-bs +name: run-librispeech-2023-01-29-stateless7-ctc-bs # zipformer on: @@ -34,7 +34,7 @@ on: - cron: "50 15 * * *" jobs: - run_librispeech_2022_12_15_zipformer_ctc_bs: + run_librispeech_2023_01_29_zipformer_ctc_bs: if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: @@ -124,7 +124,7 @@ jobs: export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh + .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' @@ -159,5 +159,5 @@ jobs: uses: actions/upload-artifact@v2 if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15 + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29 path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 0841f7cf1..c44cb1eaf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -81,20 +81,20 @@ class FrameReducer(nn.Module): fake_limit_indexes = torch.topk( ctc_output[:, :, blank_id], max_limit_len ).indices - T = ( + T_arange = ( torch.arange(max_limit_len) .expand_as( fake_limit_indexes, ) .to(device=x.device) ) - T = torch.remainder(T, limit_lens.unsqueeze(1)) - limit_indexes = torch.gather(fake_limit_indexes, 1, T) + T_arange = torch.remainder(T_arange, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, T_arange) limit_mask = torch.full_like( non_blank_mask, - False, + 0, device=x.device, - ).scatter_(1, limit_indexes, True) + ).scatter_(1, limit_indexes, 1) non_blank_mask = non_blank_mask | ~limit_mask @@ -108,9 +108,9 @@ class FrameReducer(nn.Module): ) - out_lens ) - max_pad_len = pad_lens_list.max() + max_pad_len = int(pad_lens_list.max()) - out = F.pad(x, (0, 0, 0, max_pad_len)) + out = F.pad(x, [0, 0, 0, max_pad_len]) valid_pad_mask = ~make_pad_mask(pad_lens_list) total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) From 41b16d783878fe3de304bb70285d97581e629eb5 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Sat, 8 Jul 2023 17:01:51 +0200 Subject: [PATCH 003/113] SURT recipe for AMI and ICSI (#1133) * merge upstream * add SURT model and training * add libricss decoding * add chunk width randomization * decode SURT with libricss * initial commit for zipformer_ctc * remove unwanted changes * remove changes to other recipe * fix zipformer softlink * fix for JIT export * add missing file * fix symbolic links * update results * clean commit for SURT recipe * training libricss surt model * 
remove unwanted files * remove unwanted changes * remove changes in librispeech * change some files to symlinks * remove unwanted changes in utils * add export script * add README * minor fix in README * add assets for README * replace some files with symlinks * remove unused decoding methods * initial commit for SURT AMI recipe * fix symlink * add train + decode scripts * add missing symlink * change files to symlink * change file type --- egs/ami/SURT/README.md | 156 ++ .../SURT/dprnn_zipformer/asr_datamodule.py | 399 +++++ egs/ami/SURT/dprnn_zipformer/beam_search.py | 1 + egs/ami/SURT/dprnn_zipformer/decode.py | 622 ++++++++ egs/ami/SURT/dprnn_zipformer/decoder.py | 1 + egs/ami/SURT/dprnn_zipformer/dprnn.py | 1 + .../SURT/dprnn_zipformer/encoder_interface.py | 1 + egs/ami/SURT/dprnn_zipformer/export.py | 1 + egs/ami/SURT/dprnn_zipformer/joiner.py | 1 + egs/ami/SURT/dprnn_zipformer/model.py | 1 + egs/ami/SURT/dprnn_zipformer/optim.py | 1 + egs/ami/SURT/dprnn_zipformer/scaling.py | 1 + .../SURT/dprnn_zipformer/scaling_converter.py | 1 + egs/ami/SURT/dprnn_zipformer/test_model.py | 1 + egs/ami/SURT/dprnn_zipformer/train.py | 1420 +++++++++++++++++ egs/ami/SURT/dprnn_zipformer/train_adapt.py | 1411 ++++++++++++++++ egs/ami/SURT/dprnn_zipformer/zipformer.py | 1 + egs/ami/SURT/local/add_source_feats.py | 78 + egs/ami/SURT/local/compute_fbank_aimix.py | 185 +++ egs/ami/SURT/local/compute_fbank_ami.py | 94 ++ egs/ami/SURT/local/compute_fbank_icsi.py | 95 ++ egs/ami/SURT/local/compute_fbank_ihm.py | 101 ++ egs/ami/SURT/local/prepare_ami_train_cuts.py | 146 ++ egs/ami/SURT/local/prepare_icsi_train_cuts.py | 67 + egs/ami/SURT/local/prepare_lang_bpe.py | 1 + egs/ami/SURT/local/train_bpe_model.py | 1 + egs/ami/SURT/prepare.sh | 195 +++ egs/ami/SURT/shared | 1 + 28 files changed, 4984 insertions(+) create mode 100644 egs/ami/SURT/README.md create mode 100644 egs/ami/SURT/dprnn_zipformer/asr_datamodule.py create mode 120000 egs/ami/SURT/dprnn_zipformer/beam_search.py create mode 100755 egs/ami/SURT/dprnn_zipformer/decode.py create mode 120000 egs/ami/SURT/dprnn_zipformer/decoder.py create mode 120000 egs/ami/SURT/dprnn_zipformer/dprnn.py create mode 120000 egs/ami/SURT/dprnn_zipformer/encoder_interface.py create mode 120000 egs/ami/SURT/dprnn_zipformer/export.py create mode 120000 egs/ami/SURT/dprnn_zipformer/joiner.py create mode 120000 egs/ami/SURT/dprnn_zipformer/model.py create mode 120000 egs/ami/SURT/dprnn_zipformer/optim.py create mode 120000 egs/ami/SURT/dprnn_zipformer/scaling.py create mode 120000 egs/ami/SURT/dprnn_zipformer/scaling_converter.py create mode 120000 egs/ami/SURT/dprnn_zipformer/test_model.py create mode 100755 egs/ami/SURT/dprnn_zipformer/train.py create mode 100755 egs/ami/SURT/dprnn_zipformer/train_adapt.py create mode 120000 egs/ami/SURT/dprnn_zipformer/zipformer.py create mode 100755 egs/ami/SURT/local/add_source_feats.py create mode 100755 egs/ami/SURT/local/compute_fbank_aimix.py create mode 100755 egs/ami/SURT/local/compute_fbank_ami.py create mode 100755 egs/ami/SURT/local/compute_fbank_icsi.py create mode 100755 egs/ami/SURT/local/compute_fbank_ihm.py create mode 100755 egs/ami/SURT/local/prepare_ami_train_cuts.py create mode 100755 egs/ami/SURT/local/prepare_icsi_train_cuts.py create mode 120000 egs/ami/SURT/local/prepare_lang_bpe.py create mode 120000 egs/ami/SURT/local/train_bpe_model.py create mode 100755 egs/ami/SURT/prepare.sh create mode 120000 egs/ami/SURT/shared diff --git a/egs/ami/SURT/README.md b/egs/ami/SURT/README.md new file mode 100644 index 
000000000..74a8ba014 --- /dev/null +++ b/egs/ami/SURT/README.md @@ -0,0 +1,156 @@ +# Introduction + +This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming +Unmixing and Recognition Transducer (SURT) model for the task. + +Please refer to the `egs/libricss/SURT` recipe README for details about the task and the +model. + +## Description of the recipe + +### Pre-requisites + +The recipes in this directory need the following packages to be installed: + +- [meeteval](https://github.com/fgnt/meeteval) +- [einops](https://github.com/arogozhnikov/einops) + +Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe. +Please download this checkpoint (see below) or train the LibriCSS recipe first. + +### Training + +To train the model, run the following from within `egs/ami/SURT`: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python dprnn_zipformer/train.py \ + --use-fp16 True \ + --exp-dir dprnn_zipformer/exp/surt_base \ + --world-size 4 \ + --max-duration 500 \ + --max-duration-valid 250 \ + --max-cuts 200 \ + --num-buckets 50 \ + --num-epochs 30 \ + --enable-spec-aug True \ + --enable-musan False \ + --ctc-loss-scale 0.2 \ + --heat-loss-scale 0.2 \ + --base-lr 0.004 \ + --model-init-ckpt exp/libricss_base.pt \ + --chunk-width-randomization True \ + --num-mask-encoder-layers 4 \ + --num-encoder-layers 2,2,2,2,2 +``` + +The above is for SURT-base (~26M). For SURT-large (~38M), use: + +```bash + --model-init-ckpt exp/libricss_large.pt \ + --num-mask-encoder-layers 6 \ + --num-encoder-layers 2,4,3,2,4 \ + --model-init-ckpt exp/zipformer_large.pt \ +``` + +**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM. + +### Adaptation + +The training step above only trains on simulated mixtures. For best results, we also +adapt the final model on the AMI+ICSI train set. 
For this, run the following from within +`egs/ami/SURT`: + +```bash +export CUDA_VISIBLE_DEVICES="0" + +python dprnn_zipformer/train_adapt.py \ + --use-fp16 True \ + --exp-dir dprnn_zipformer/exp/surt_base_adapt \ + --world-size 4 \ + --max-duration 500 \ + --max-duration-valid 250 \ + --max-cuts 200 \ + --num-buckets 50 \ + --num-epochs 8 \ + --lr-epochs 2 \ + --enable-spec-aug True \ + --enable-musan False \ + --ctc-loss-scale 0.2 \ + --base-lr 0.0004 \ + --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \ + --chunk-width-randomization True \ + --num-mask-encoder-layers 4 \ + --num-encoder-layers 2,2,2,2,2 +``` + +For SURT-large, use the following config: + +```bash + --num-mask-encoder-layers 6 \ + --num-encoder-layers 2,4,3,2,4 \ + --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \ + --num-epochs 15 \ + --lr-epochs 4 \ +``` + + +### Decoding + +To decode the model, run the following from within `egs/ami/SURT`: + +#### Greedy search + +```bash +export CUDA_VISIBLE_DEVICES="0" + +python dprnn_zipformer/decode.py \ + --epoch 20 --avg 1 --use-averaged-model False \ + --exp-dir dprnn_zipformer/exp/surt_base_adapt \ + --max-duration 250 \ + --decoding-method greedy_search +``` + +#### Beam search + +```bash +python dprnn_zipformer/decode.py \ + --epoch 20 --avg 1 --use-averaged-model False \ + --exp-dir dprnn_zipformer/exp/surt_base_adapt \ + --max-duration 250 \ + --decoding-method modified_beam_search \ + --beam-size 4 +``` + +## Results (using beam search) + +**AMI** + +| Model | IHM-Mix | SDM | MDM | +|------------|:-------:|:----:|:----:| +| SURT-base | 39.8 | 65.4 | 46.6 | +| + adapt | 37.4 | 46.9 | 43.7 | +| SURT-large | 36.8 | 62.5 | 44.4 | +| + adapt | **35.1** | **44.6** | **41.4** | + +**ICSI** + +| Model | IHM-Mix | SDM | +|------------|:-------:|:----:| +| SURT-base | 28.3 | 60.0 | +| + adapt | 26.3 | 33.9 | +| SURT-large | 27.8 | 59.7 | +| + adapt | **24.4** | **32.3** | + +## Pre-trained models and logs + +* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large) + +* Pre-trained models: + +* Training logs: + - surt_base: + - surt_base_adapt: + - surt_large: + - surt_large_adapt: diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py new file mode 100644 index 000000000..ec8106bc3 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py @@ -0,0 +1,399 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutMix, + DynamicBucketingSampler, + K2SurtDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AmiAsrDataModule: + """ + DataModule for k2 SURT experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--max-duration-valid", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--max-cuts", + type=int, + default=100, + help="Maximum number of cuts in a single batch. You can " + "reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + sources: bool = False, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. 
+ num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SurtDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + return_sources=sources, + strict=False, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + quadratic_duration=30.0, + max_cuts=self.args.max_cuts, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + max_cuts=self.args.max_cuts, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + + logging.info("About to create dev dataset") + validate = K2SurtDataset( + input_strategy=OnTheFlyFeatures( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + return_sources=False, + strict=False, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration_valid, + quadratic_duration=30.0, + max_cuts=self.args.max_cuts, + shuffle=False, + ) + logging.info("About to create dev dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SurtDataset( + input_strategy=OnTheFlyFeatures( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + return_sources=False, + strict=False, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration_valid, + max_cuts=self.args.max_cuts, + shuffle=False, + ) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + return test_dl + + @lru_cache() + def aimix_train_cuts( + self, + rvb_affix: str = "clean", + sources: bool = True, + ) -> CutSet: + logging.info("About to get train cuts") + source_affix = "_sources" if sources else "" + cs = load_manifest_lazy( + self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz" + ) + cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) + return cs + + @lru_cache() + def train_cuts( + self, + ) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz" + ) + + @lru_cache() + def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet: + logging.info(f"About to get AMI {split} {type} cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz" + ) + + @lru_cache() + def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet: + logging.info(f"About to get ICSI {split} {type} cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz" + ) diff --git a/egs/ami/SURT/dprnn_zipformer/beam_search.py b/egs/ami/SURT/dprnn_zipformer/beam_search.py new file mode 120000 index 000000000..581b29833 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/beam_search.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/beam_search.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/decode.py b/egs/ami/SURT/dprnn_zipformer/decode.py new file mode 100755 index 000000000..d1a1eddc9 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/decode.py @@ -0,0 +1,622 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./dprnn_zipformer/decode.py \ + --epoch 20 \ + --avg 1 \ + --use-averaged-model false \ + --exp-dir ./dprnn_zipformer/exp_adapt \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./dprnn_zipformer/decode.py \ + --epoch 20 \ + --avg 1 \ + --use-averaged-model false \ + --exp-dir ./dprnn_zipformer/exp_adapt \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./dprnn_zipformer/decode.py \ + --epoch 20 \ + --avg 1 \ + --use-averaged-model false \ + --exp-dir ./dprnn_zipformer/exp_adapt \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.utils import EPSILON +from train import add_model_arguments, get_params, get_surt_model + +from icefall import LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_surt_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="dprnn_zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + feature_lens = batch["input_lens"].to(device) + + # Apply the mask encoder + B, T, F = feature.shape + processed = model.mask_encoder(feature) # B,T,F*num_channels + masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) + x_masked = [feature * m for m in masks] + + # Recognition + # Stack the inputs along the batch axis + h = torch.cat(x_masked, dim=0) + h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0) + encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) + + if model.joint_encoder_layer is not None: + encoder_out = model.joint_encoder_layer(encoder_out) + + def _group_channels(hyps: List[str]) -> List[List[str]]: + """ + Currently we have a batch of size M*B, where M is the number of + channels and B is the batch size. We need to group the hypotheses + into B groups, each of which contains M hypotheses. 
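# decode_one_batch above runs the SURT front-end once for all output channels:
# the mask encoder emits `num_channels` masks, each masked copy of the input is
# stacked along the batch axis, and a single encoder pass serves all channels.
# A shape-only sketch of that flow with random stand-in tensors (2 channels):
import torch

B, T, F, num_channels = 4, 100, 80, 2
feature = torch.randn(B, T, F)
processed = torch.randn(B, T, F * num_channels)  # stand-in for mask_encoder(feature)
masks = processed.view(B, T, F, num_channels).unbind(dim=-1)
x_masked = [feature * m for m in masks]          # num_channels tensors of shape (B, T, F)
h = torch.cat(x_masked, dim=0)                   # (num_channels * B, T, F)
assert h.shape == (num_channels * B, T, F)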
+ + Example: + hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] + _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] + """ + assert len(hyps) == B * params.num_channels + out_hyps = [] + for i in range(B): + out_hyps.append(hyps[i::B]) + return out_hyps + + hyps = [] + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp)) + + if params.decoding_method == "greedy_search": + return {"greedy_search": _group_channels(hyps)} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: _group_channels(hyps)} + else: + return {f"beam_size_{params.beam_size}": _group_channels(hyps)} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + cut_ids = [cut.id for cut in batch["cuts"]] + cuts_batch = batch["cuts"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + for cut_id, hyp_words in zip(cut_ids, hyps): + # Reference is a list of supervision texts sorted by start time. 
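# The M*B hypotheses (channels stacked along the batch axis) are regrouped so
# that each of the B utterances receives its M channel outputs. A tiny
# standalone check of the slicing used by `_group_channels`:
B, num_channels = 3, 2
hyps = ["a1", "b1", "c1", "a2", "b2", "c2"]  # channel 1 for all cuts, then channel 2
grouped = [hyps[i::B] for i in range(B)]
assert grouped == [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]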
+ ref_words = [ + s.text.strip() + for s in sorted( + cuts_batch[cut_id].supervisions, key=lambda s: s.start + ) + ] + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(cut_ids) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_surt_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + num_channels=params.num_channels, + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LmScorer.add_arguments(parser) + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ), f"Decoding method {params.decoding_method} is not supported." 
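# References for each cut are built by concatenating its supervision texts in
# temporal order, as done in decode_dataset above. A minimal sketch with a
# hypothetical cut-like object (SimpleNamespace stands in for a lhotse Cut):
from types import SimpleNamespace

cut = SimpleNamespace(
    supervisions=[
        SimpleNamespace(start=2.0, text=" world "),
        SimpleNamespace(start=0.0, text="hello"),
    ]
)
ref_words = [s.text.strip() for s in sorted(cut.supervisions, key=lambda s: s.start)]
assert ref_words == ["hello", "world"]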
+ params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_surt_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + 
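# With --use-averaged-model, parameters are averaged over the epoch range
# (epoch - avg, epoch], i.e. epoch - avg is the excluded endpoint. A scalar
# sketch of the filename arithmetic used in the branch above (exp/ is a
# placeholder directory):
epoch, avg = 20, 1
start = epoch - avg
filename_start = f"exp/epoch-{start}.pt"  # excluded endpoint
filename_end = f"exp/epoch-{epoch}.pt"    # included endpoint
assert (filename_start, filename_end) == ("exp/epoch-19.pt", "exp/epoch-20.pt")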
model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ami = AmiAsrDataModule(args) + + # NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding. + # However, 99.9% of the segments are shorter than 120s, so this should not + # substantially affect the results. In future, we will implement an overlapped + # inference method to avoid OOM errors. + + test_sets = {} + for split in ["dev", "test"]: + for type in ["ihm-mix", "sdm", "mdm8-bf"]: + test_sets[f"ami-{split}_{type}"] = ( + ami.ami_cuts(split=split, type=type) + .trim_to_supervision_groups(max_pause=0.0) + .filter(lambda c: 0.1 < c.duration < 120.0) + .to_eager() + ) + + for split in ["dev", "test"]: + for type in ["ihm-mix", "sdm"]: + test_sets[f"icsi-{split}_{type}"] = ( + ami.icsi_cuts(split=split, type=type) + .trim_to_supervision_groups(max_pause=0.0) + .filter(lambda c: 0.1 < c.duration < 120.0) + .to_eager() + ) + + for test_set, test_cuts in test_sets.items(): + test_dl = ami.test_dataloaders(test_cuts) + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ami/SURT/dprnn_zipformer/decoder.py b/egs/ami/SURT/dprnn_zipformer/decoder.py new file mode 120000 index 000000000..c34865c25 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/decoder.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/decoder.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/dprnn.py b/egs/ami/SURT/dprnn_zipformer/dprnn.py new file mode 120000 index 000000000..8918beb32 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/dprnn.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/dprnn.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/encoder_interface.py b/egs/ami/SURT/dprnn_zipformer/encoder_interface.py new file mode 120000 index 000000000..0ba945d0f --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/export.py b/egs/ami/SURT/dprnn_zipformer/export.py new file mode 120000 index 000000000..3deae4471 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/export.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/export.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/joiner.py b/egs/ami/SURT/dprnn_zipformer/joiner.py new file mode 120000 index 000000000..79fbe8769 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/joiner.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/joiner.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/model.py b/egs/ami/SURT/dprnn_zipformer/model.py new file mode 120000 index 000000000..ae8c65c99 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/model.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/model.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/optim.py b/egs/ami/SURT/dprnn_zipformer/optim.py new file mode 120000 index 000000000..366d0f7a2 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/optim.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/optim.py \ No newline at end of file diff --git 
a/egs/ami/SURT/dprnn_zipformer/scaling.py b/egs/ami/SURT/dprnn_zipformer/scaling.py new file mode 120000 index 000000000..f11d49d77 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/scaling.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/scaling.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/scaling_converter.py b/egs/ami/SURT/dprnn_zipformer/scaling_converter.py new file mode 120000 index 000000000..1533cbe0e --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/test_model.py b/egs/ami/SURT/dprnn_zipformer/test_model.py new file mode 120000 index 000000000..1259849e0 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py new file mode 100755 index 000000000..cd5fafc34 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/train.py @@ -0,0 +1,1420 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +cd egs/ami/SURT/ +./prepare.sh + +./dprnn_zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir dprnn_zipformer/exp \ + --max-duration 650 +""" + +import argparse +import copy +import logging +import warnings +from itertools import chain +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from decoder import Decoder +from dprnn import DPRNN +from einops.layers.torch import Rearrange +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import LOG_EPSILON, fix_random_seed +from model import SURT +from optim import Eden, ScaledAdam +from scaling import ScaledLinear, ScaledLSTM +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-mask-encoder-layers", + type=int, + default=4, + help="Number of layers in the DPRNN based mask encoder.", + ) + + parser.add_argument( + "--mask-encoder-dim", + type=int, + default=256, + help="Hidden dimension of the LSTM blocks in DPRNN.", + ) + + parser.add_argument( + "--mask-encoder-segment-size", + type=int, + default=32, + help="Segment size of the SegLSTM in DPRNN. 
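# set_batch_count above pushes the global batch counter into every submodule
# that exposes a `batch_count` attribute; modules that key their behaviour on
# training progress read it back. A standalone sketch of the same traversal
# with a toy module:
import torch.nn as nn


class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.batch_count = 0.0


toy = ToyModule()
for module in toy.modules():
    if hasattr(module, "batch_count"):
        module.batch_count = 1000.0
assert toy.batch_count == 1000.0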
Ideally, this should be equal to the " + "decode-chunk-length of the zipformer encoder.", + ) + + parser.add_argument( + "--chunk-width-randomization", + type=bool, + default=False, + help="Whether to randomize the chunk width in DPRNN.", + ) + + # Zipformer config is based on: + # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,2,2,2", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="768,768,768,768,768", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="256,256,256,256,256", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="192,192,192,192,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--use-joint-encoder-layer", + type=str, + default="lstm", + choices=["linear", "lstm", "none"], + help="Whether to use a joint layer to combine all branches.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
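# The zipformer stacks are configured with comma-separated per-stack strings
# (layers, dims, heads, downsampling factors, ...); the model constructor later
# splits them into integer tuples with a small helper like `to_int_tuple` used
# in get_encoder_model below. Minimal sketch:
def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))


assert to_int_tuple("2,2,2,2,2") == (2, 2, 2, 2, 2)
assert to_int_tuple("1,2,4,8,2")[-1] == 2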
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conv_lstm_transducer_stateless_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--model-init-ckpt", + type=str, + default=None, + help="""The model checkpoint to initialize the model (either full or part). + If not specified, the model is randomly initialized. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.004, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--heat-loss-scale", + type=float, + default=0.2, + help="Scale for HEAT loss on separated sources.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=1, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
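# model_avg keeps a running average of all parameters since the start of
# training; every --average-period batches it is refreshed with the current
# weights using the formula quoted in the help text above. A scalar sketch:
average_period = 100
batch_idx_train = 400
model_param = 2.0      # current value of some parameter
model_avg_param = 1.0  # its running average over the first 300 batches
model_avg_param = model_param * (average_period / batch_idx_train) + model_avg_param * (
    (batch_idx_train - average_period) / batch_idx_train
)
assert model_avg_param == 1.25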
+ + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 2000, + # parameters for SURT + "num_channels": 2, + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed + # parameters for Noam + "model_warm_step": 5000, # arg given to model, not for lrate + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + + return params + + +def get_mask_encoder_model(params: AttributeDict) -> nn.Module: + mask_encoder = DPRNN( + feature_dim=params.feature_dim, + input_size=params.mask_encoder_dim, + hidden_size=params.mask_encoder_dim, + output_size=params.feature_dim * params.num_channels, + segment_size=params.mask_encoder_segment_size, + num_blocks=params.num_mask_encoder_layers, + chunk_width_randomization=params.chunk_width_randomization, + ) + return mask_encoder + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_joint_encoder_layer(params: AttributeDict) -> nn.Module: + class TakeFirst(nn.Module): + def forward(self, x): + return x[0] + + if params.use_joint_encoder_layer == "linear": + encoder_dim = int(params.encoder_dims.split(",")[-1]) + joint_layer = nn.Sequential( + Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), + nn.Linear( + params.num_channels * encoder_dim, params.num_channels * encoder_dim + ), + nn.ReLU(), + Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), + ) + elif params.use_joint_encoder_layer == "lstm": + encoder_dim = int(params.encoder_dims.split(",")[-1]) + joint_layer = nn.Sequential( + Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), + ScaledLSTM( + input_size=params.num_channels * encoder_dim, + hidden_size=params.num_channels * encoder_dim, + num_layers=1, + bias=True, + batch_first=True, + dropout=0.0, + bidirectional=False, + ), + TakeFirst(), + nn.ReLU(), + Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), + ) + elif params.use_joint_encoder_layer == "none": + joint_layer = None + else: + raise ValueError( + f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}" + ) + return joint_layer + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + 
encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_surt_model( + params: AttributeDict, +) -> nn.Module: + mask_encoder = get_mask_encoder_model(params) + encoder = get_encoder_model(params) + joint_layer = get_joint_encoder_layer(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = SURT( + mask_encoder=mask_encoder, + encoder=encoder, + joint_encoder_layer=joint_layer, + decoder=decoder, + joiner=joiner, + num_channels=params.num_channels, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
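# The LSTM/linear joint encoder layer folds the channel dimension out of the
# batch axis, mixes information across channels, and folds it back, as in
# get_joint_encoder_layer above. A shape-only sketch of that rearrangement
# using the functional einops API (the recipe itself uses the Rearrange layer):
import torch
from einops import rearrange

c, b, t, d = 2, 3, 10, 8
x = torch.randn(c * b, t, d)                     # channels stacked along batch
y = rearrange(x, "(c b) t d -> b t (c d)", c=c)  # (3, 10, 16): channels side by side
z = rearrange(y, "b t (c d) -> (c b) t d", c=c)  # back to (6, 10, 8)
assert torch.equal(x, z)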
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_heat_loss(x_masked, batch, num_channels=2) -> Tensor: + """ + Compute HEAT loss for separated sources using the output of mask encoder. + Args: + x_masked: + The output of mask encoder. It is a tensor of shape (B, T, C). + batch: + A batch of data. See `lhotse.dataset.K2SurtDatasetWithSources()` + for the content in it. + num_channels: + The number of output branches in the SURT model. + """ + B, T, D = x_masked[0].shape + device = x_masked[0].device + + # Create training targets for each channel. + targets = [] + for i in range(num_channels): + target = torch.ones_like(x_masked[i]) * LOG_EPSILON + targets.append(target) + + source_feats = batch["source_feats"] + source_boundaries = batch["source_boundaries"] + input_lens = batch["input_lens"].to(device) + # Assign sources to channels based on the HEAT criteria + for b in range(B): + cut_source_feats = source_feats[b] + cut_source_boundaries = source_boundaries[b] + last_seg_end = [0 for _ in range(num_channels)] + for source_feat, (start, end) in zip(cut_source_feats, cut_source_boundaries): + assigned = False + end = min(end, T) + source_feat = source_feat[: end - start, :] + for i in range(num_channels): + if start >= last_seg_end[i]: + targets[i][b, start:end, :] += source_feat.to(device) + last_seg_end[i] = max(end, last_seg_end[i]) + assigned = True + break + if not assigned: + min_end_channel = last_seg_end.index(min(last_seg_end)) + targets[min_end_channel][b, start:end, :] += source_feat.to(device) + last_seg_end[min_end_channel] = max(end, last_seg_end[min_end_channel]) + + # Get padding mask based on input lengths + pad_mask = torch.arange(T, device=device).expand(B, T) > input_lens.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(-1) + + # Compute masked loss for each channel + losses = torch.zeros((num_channels, B, T, D), device=device) + for i in range(num_channels): + loss = nn.functional.mse_loss(x_masked[i], targets[i], reduction="none") + # Apply padding mask to loss + loss.masked_fill_(pad_mask, 0) + losses[i] = loss + + # loss: C x B x T x D. pad_mask: B x T x 1 + # We want to compute loss for each item in the batch. Each item has loss given + # by the sum over C, and average over T and D. For T, we need to use the padding. + loss = losses.sum(0).mean(-1).sum(-1) / batch["input_lens"].to(device) + return loss + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
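# compute_heat_loss above assigns each source segment to the first output
# channel whose previous segment already ended before the new one starts,
# falling back to the channel that frees up earliest. A minimal sketch of that
# assignment rule on segment boundaries only (no features), with hypothetical
# frame indices:
def heat_assign(boundaries, num_channels=2):
    last_seg_end = [0] * num_channels
    assignment = []
    for start, end in boundaries:
        for i in range(num_channels):
            if start >= last_seg_end[i]:
                assignment.append(i)
                last_seg_end[i] = max(end, last_seg_end[i])
                break
        else:
            i = last_seg_end.index(min(last_seg_end))
            assignment.append(i)
            last_seg_end[i] = max(end, last_seg_end[i])
    return assignment


# Two overlapping segments land on different channels; a later one reuses channel 0.
assert heat_assign([(0, 50), (20, 80), (60, 100)]) == [0, 1, 0]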
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"].to(device) + feature_lens = batch["input_lens"].to(device) + + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + + # The dataloader returns text as a list of cuts, each of which is a list of channel + # text. We flatten this to a list where all channels are together, i.e., it looks like + # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. + text = [val for tup in zip(*batch["text"]) for val in tup] + assert len(text) == len(feature) * params.num_channels + + # Convert all channel texts to token IDs and create a ragged tensor. + y = sp.encode(text, out_type=int) + y = k2.RaggedTensor(y).to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.model_warm_step + + with torch.set_grad_enabled(is_training): + (simple_loss, pruned_loss, ctc_loss, x_masked) = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + reduction="none", + subsampling_factor=params.subsampling_factor, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + ctc_loss_is_finite = torch.isfinite(ctc_loss) + + # Compute HEAT loss + if is_training and params.heat_loss_scale > 0.0: + heat_loss = compute_heat_loss( + x_masked, batch, num_channels=params.num_channels + ) + else: + heat_loss = torch.tensor(0.0, device=device) + + heat_loss_is_finite = torch.isfinite(heat_loss) + is_finite = ( + simple_loss_is_finite + & pruned_loss_is_finite + & ctc_loss_is_finite + & heat_loss_is_finite + ) + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_losses: {simple_loss}\n" + f"pruned_losses: {pruned_loss}\n" + f"ctc_losses: {ctc_loss}\n" + f"heat_losses: {heat_loss}\n" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + ctc_loss = ctc_loss[ctc_loss_is_finite] + heat_loss = heat_loss[heat_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if ( + torch.all(~simple_loss_is_finite) + or torch.all(~pruned_loss_is_finite) + or torch.all(~ctc_loss_is_finite) + or torch.all(~heat_loss_is_finite) + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss_sum = simple_loss.sum() + pruned_loss_sum = pruned_loss.sum() + ctc_loss_sum = ctc_loss.sum() + heat_loss_sum = heat_loss.sum() + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
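# During the warm-up period the total loss interpolates between weightings: the
# simple-loss weight decays linearly from 1.0 to params.simple_loss_scale while
# the pruned-loss weight grows from 0.1 to 1.0, matching the expressions that
# follow. A scalar sketch with s = 0.5 and warm_step = 5000:
s, warm_step = 0.5, 5000
for batch_idx_train in (0, 2500, 5000):
    simple_scale = s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    pruned_scale = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step)
    print(batch_idx_train, simple_scale, pruned_scale)
# prints (0, 1.0, 0.1), (2500, 0.75, 0.55), (5000, 0.5, 1.0)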
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss = ( + simple_loss_scale * simple_loss_sum + + pruned_loss_scale * pruned_loss_sum + + params.ctc_loss_scale * ctc_loss_sum + + params.heat_loss_scale * heat_loss_sum + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss_sum.detach().cpu().item() + info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() + if params.ctc_loss_scale > 0.0: + info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() + if params.heat_loss_scale > 0.0: + info["heat_loss"] = heat_loss_sum.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + torch.cuda.empty_cache() + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = batch["inputs"].shape[0] + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
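# When fp16 training pushes the GradScaler's scale below 1.0 (or keeps it under
# 8.0 for a long stretch), the loop above doubles it manually rather than
# waiting for the scaler's own growth interval. A scalar sketch of that
# decision rule, which the real loop checks every 100 batches:
def maybe_grow(cur_grad_scale: float, batch_idx: int) -> float:
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return cur_grad_scale * 2.0
    return cur_grad_scale


assert maybe_grow(0.5, 100) == 1.0
assert maybe_grow(4.0, 400) == 8.0
assert maybe_grow(4.0, 100) == 4.0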
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_surt_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + + if checkpoints is None and params.model_init_ckpt is not None: + logging.info( + f"Initializing model with checkpoint from {params.model_init_ckpt}" + ) + init_ckpt = torch.load(params.model_init_ckpt, map_location=device) + model.load_state_dict(init_ckpt["model"], strict=False) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + ami = AmiAsrDataModule(args) + + train_cuts = ami.aimix_train_cuts(rvb_affix="comb", sources=True) + dev_cuts = ami.ami_cuts(split="dev", type="ihm-mix") + dev_cuts = dev_cuts.trim_to_supervision_groups(max_pause=0.0).filter( + lambda c: 0.2 <= c.duration <= 60.0 + ) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = ami.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + sources=True, + ) + valid_dl = ami.valid_dataloaders(dev_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + 
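# Each epoch re-seeds the RNGs with seed + epoch - 1 and passes the epoch to
# the sampler, so shuffling differs per epoch yet an interrupted run resumes
# deterministically. Minimal sketch of the per-epoch reseeding done just below:
from lhotse.utils import fix_random_seed

seed = 42
for epoch in range(1, 4):
    fix_random_seed(seed + epoch - 1)  # epoch 1 -> 42, epoch 2 -> 43, ...
    # train_dl.sampler.set_epoch(epoch - 1) happens here in the real loop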
fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = [sp.encode(text_ch) for text_ch in batch["text"]] + num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] + logging.info(f"num tokens: {num_tokens}") + + +def main(): + parser = get_parser() + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + +if __name__ == "__main__": + main() diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py new file mode 100755 index 000000000..9f3b4425f --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -0,0 +1,1411 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +# ./dprnn_zipformer/train.py should be run before this script. 
+ +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./dprnn_zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir dprnn_zipformer/exp_adapt \ + --model-init-ckpt dprnn_zipformer/exp/epoch-30.pt \ + --max-duration 550 +""" + +import argparse +import copy +import logging +import warnings +from itertools import chain +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from decoder import Decoder +from dprnn import DPRNN +from einops.layers.torch import Rearrange +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import LOG_EPSILON, fix_random_seed +from model import SURT +from optim import Eden, ScaledAdam +from scaling import ScaledLinear, ScaledLSTM +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-mask-encoder-layers", + type=int, + default=4, + help="Number of layers in the DPRNN based mask encoder.", + ) + + parser.add_argument( + "--mask-encoder-dim", + type=int, + default=256, + help="Hidden dimension of the LSTM blocks in DPRNN.", + ) + + parser.add_argument( + "--mask-encoder-segment-size", + type=int, + default=32, + help="Segment size of the SegLSTM in DPRNN. 
Ideally, this should be equal to the " + "decode-chunk-length of the zipformer encoder.", + ) + + parser.add_argument( + "--chunk-width-randomization", + type=bool, + default=False, + help="Whether to randomize the chunk width in DPRNN.", + ) + + # Zipformer config is based on: + # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,2,2,2", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="768,768,768,768,768", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="256,256,256,256,256", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="192,192,192,192,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--use-joint-encoder-layer", + type=str, + default="linear", + choices=["linear", "lstm", "none"], + help="Whether to use a joint layer to combine all branches.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conv_lstm_transducer_stateless_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--model-init-ckpt", + type=str, + default=None, + help="""The model checkpoint to initialize the model (either full or part). + If not specified, the model is randomly initialized. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.0001, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=2, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=1, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
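+    - num_channels: The number of output branches in the SURT model.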
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 2000, + # parameters for SURT + "num_channels": 2, + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed + # parameters for Noam + "model_warm_step": 5000, # arg given to model, not for lrate + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + + return params + + +def get_mask_encoder_model(params: AttributeDict) -> nn.Module: + mask_encoder = DPRNN( + feature_dim=params.feature_dim, + input_size=params.mask_encoder_dim, + hidden_size=params.mask_encoder_dim, + output_size=params.feature_dim * params.num_channels, + segment_size=params.mask_encoder_segment_size, + num_blocks=params.num_mask_encoder_layers, + chunk_width_randomization=params.chunk_width_randomization, + ) + return mask_encoder + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_joint_encoder_layer(params: AttributeDict) -> nn.Module: + class TakeFirst(nn.Module): + def forward(self, x): + return x[0] + + if params.use_joint_encoder_layer == "linear": + encoder_dim = int(params.encoder_dims.split(",")[-1]) + joint_layer = nn.Sequential( + Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), + nn.Linear( + params.num_channels * encoder_dim, params.num_channels * encoder_dim + ), + nn.ReLU(), + Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), + ) + elif params.use_joint_encoder_layer == "lstm": + encoder_dim = int(params.encoder_dims.split(",")[-1]) + joint_layer = nn.Sequential( + Rearrange("(c b) t d -> b t (c d)", c=params.num_channels), + ScaledLSTM( + input_size=params.num_channels * encoder_dim, + hidden_size=params.num_channels * encoder_dim, + num_layers=1, + bias=True, + batch_first=True, + dropout=0.0, + bidirectional=False, + ), + TakeFirst(), + nn.ReLU(), + Rearrange("b t (c d) -> (c b) t d", c=params.num_channels), + ) + elif params.use_joint_encoder_layer == "none": + joint_layer = None + else: + raise ValueError( + f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}" + ) + return joint_layer + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + 
decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_surt_model( + params: AttributeDict, +) -> nn.Module: + mask_encoder = get_mask_encoder_model(params) + encoder = get_encoder_model(params) + joint_layer = get_joint_encoder_layer(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = SURT( + mask_encoder=mask_encoder, + encoder=encoder, + joint_encoder_layer=joint_layer, + decoder=decoder, + joiner=joiner, + num_channels=params.num_channels, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
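+      scheduler:
+        The learning rate scheduler used in the training.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.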
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_heat_loss(x_masked, batch, num_channels=2) -> Tensor: + """ + Compute HEAT loss for separated sources using the output of mask encoder. + Args: + x_masked: + The output of mask encoder. It is a tensor of shape (B, T, C). + batch: + A batch of data. See `lhotse.dataset.K2SurtDatasetWithSources()` + for the content in it. + num_channels: + The number of output branches in the SURT model. + """ + B, T, D = x_masked[0].shape + device = x_masked[0].device + + # Create training targets for each channel. + targets = [] + for i in range(num_channels): + target = torch.ones_like(x_masked[i]) * LOG_EPSILON + targets.append(target) + + source_feats = batch["source_feats"] + source_boundaries = batch["source_boundaries"] + input_lens = batch["input_lens"].to(device) + # Assign sources to channels based on the HEAT criteria + for b in range(B): + cut_source_feats = source_feats[b] + cut_source_boundaries = source_boundaries[b] + last_seg_end = [0 for _ in range(num_channels)] + for source_feat, (start, end) in zip(cut_source_feats, cut_source_boundaries): + assigned = False + for i in range(num_channels): + if start >= last_seg_end[i]: + targets[i][b, start:end, :] += source_feat.to(device) + last_seg_end[i] = max(end, last_seg_end[i]) + assigned = True + break + if not assigned: + min_end_channel = last_seg_end.index(min(last_seg_end)) + targets[min_end_channel][b, start:end, :] += source_feat + last_seg_end[min_end_channel] = max(end, last_seg_end[min_end_channel]) + + # Get padding mask based on input lengths + pad_mask = torch.arange(T, device=device).expand(B, T) > input_lens.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(-1) + + # Compute masked loss for each channel + losses = torch.zeros((num_channels, B, T, D), device=device) + for i in range(num_channels): + loss = nn.functional.mse_loss(x_masked[i], targets[i], reduction="none") + # Apply padding mask to loss + loss.masked_fill_(pad_mask, 0) + losses[i] = loss + + # loss: C x B x T x D. pad_mask: B x T x 1 + # We want to compute loss for each item in the batch. Each item has loss given + # by the sum over C, and average over T and D. For T, we need to use the padding. + loss = losses.sum(0).mean(-1).sum(-1) / batch["input_lens"].to(device) + return loss + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"].to(device) + feature_lens = batch["input_lens"].to(device) + + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + + # The dataloader returns text as a list of cuts, each of which is a list of channel + # text. We flatten this to a list where all channels are together, i.e., it looks like + # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. + text = [val for tup in zip(*batch["text"]) for val in tup] + assert len(text) == len(feature) * params.num_channels + + # Convert all channel texts to token IDs and create a ragged tensor. + y = sp.encode(text, out_type=int) + y = k2.RaggedTensor(y).to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.model_warm_step + + with torch.set_grad_enabled(is_training): + (simple_loss, pruned_loss, ctc_loss, x_masked) = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + reduction="none", + subsampling_factor=params.subsampling_factor, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + ctc_loss_is_finite = torch.isfinite(ctc_loss) + + # Compute HEAT loss + if is_training and params.heat_loss_scale > 0.0: + heat_loss = compute_heat_loss( + x_masked, batch, num_channels=params.num_channels + ) + else: + heat_loss = torch.tensor(0.0, device=device) + + heat_loss_is_finite = torch.isfinite(heat_loss) + is_finite = ( + simple_loss_is_finite + & pruned_loss_is_finite + & ctc_loss_is_finite + & heat_loss_is_finite + ) + if not torch.all(is_finite): + # logging.info( + # "Not all losses are finite!\n" + # f"simple_losses: {simple_loss}\n" + # f"pruned_losses: {pruned_loss}\n" + # f"ctc_losses: {ctc_loss}\n" + # f"heat_losses: {heat_loss}\n" + # ) + # display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + ctc_loss = ctc_loss[ctc_loss_is_finite] + heat_loss = heat_loss[heat_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if ( + torch.all(~simple_loss_is_finite) + or torch.all(~pruned_loss_is_finite) + or torch.all(~ctc_loss_is_finite) + or torch.all(~heat_loss_is_finite) + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss_sum = simple_loss.sum() + pruned_loss_sum = pruned_loss.sum() + ctc_loss_sum = ctc_loss.sum() + heat_loss_sum = heat_loss.sum() + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss = ( + simple_loss_scale * simple_loss_sum + + pruned_loss_scale * pruned_loss_sum + + params.ctc_loss_scale * ctc_loss_sum + + params.heat_loss_scale * heat_loss_sum + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss_sum.detach().cpu().item() + info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() + if params.ctc_loss_scale > 0.0: + info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() + if params.heat_loss_scale > 0.0: + info["heat_loss"] = heat_loss_sum.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + torch.cuda.empty_cache() + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = batch["inputs"].shape[0] + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
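+            # Instead, we double the scale manually: every 100 batches while it
+            # is below 1.0, and every 400 batches while it is below 8.0. A very
+            # small grad scale would effectively zero out the fp16 gradients.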
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_surt_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + + if checkpoints is None and params.model_init_ckpt is not None: + logging.info( + f"Initializing model with checkpoint from {params.model_init_ckpt}" + ) + init_ckpt = torch.load(params.model_init_ckpt, map_location=device) + model.load_state_dict(init_ckpt["model"], strict=False) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + ami = AmiAsrDataModule(args) + + train_cuts = ami.train_cuts() + train_cuts = train_cuts.filter(lambda c: 0.5 <= c.duration <= 35.0) + dev_cuts = ami.ami_cuts(split="dev", type="ihm-mix") + dev_cuts = dev_cuts.trim_to_supervision_groups(max_pause=0.0).filter( + lambda c: 0.2 <= c.duration <= 60.0 + ) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = ami.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + ) + valid_dl = ami.valid_dataloaders(dev_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch 
- 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = [sp.encode(text_ch) for text_ch in batch["text"]] + num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] + logging.info(f"num tokens: {num_tokens}") + + +def main(): + parser = get_parser() + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + +if __name__ == "__main__": + main() diff --git a/egs/ami/SURT/dprnn_zipformer/zipformer.py b/egs/ami/SURT/dprnn_zipformer/zipformer.py new file mode 120000 index 000000000..59b772024 --- /dev/null +++ b/egs/ami/SURT/dprnn_zipformer/zipformer.py @@ -0,0 +1 @@ +../../../libricss/SURT/dprnn_zipformer/zipformer.py \ No newline at end of file diff --git a/egs/ami/SURT/local/add_source_feats.py b/egs/ami/SURT/local/add_source_feats.py new file mode 100755 index 000000000..0917b88a6 --- /dev/null +++ b/egs/ami/SURT/local/add_source_feats.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file adds source features as temporal arrays to the mixture manifests. +It looks for manifests in the directory data/manifests. 
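+The attached source features are later used as targets for the HEAT loss when
+training the SURT model.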
+""" +import logging +from pathlib import Path + +import numpy as np +from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy +from tqdm import tqdm + + +def add_source_feats(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + logging.info("Reading mixed cuts") + mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz") + mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz") + + logging.info("Reading source cuts") + source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz") + + logging.info("Adding source features to the mixed cuts") + pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features") + with CutSet.open_writer( + src_dir / "cuts_train_clean_sources.jsonl.gz" + ) as cut_writer_clean, CutSet.open_writer( + src_dir / "cuts_train_reverb_sources.jsonl.gz" + ) as cut_writer_reverb, LilcomChunkyWriter( + output_dir / "feats_train_clean_sources" + ) as source_feat_writer: + for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb): + assert cut_reverb.id == cut_clean.id + "_rvb" + source_feats = [] + source_feat_offsets = [] + cur_offset = 0 + for sup in sorted( + cut_clean.supervisions, key=lambda s: (s.start, s.speaker) + ): + source_cut = source_cuts[sup.id] + source_feats.append(source_cut.load_features()) + source_feat_offsets.append(cur_offset) + cur_offset += source_cut.num_frames + cut_clean.source_feats = source_feat_writer.store_array( + cut_clean.id, np.concatenate(source_feats, axis=0) + ) + cut_clean.source_feat_offsets = source_feat_offsets + cut_writer_clean.write(cut_clean) + # Also write the reverb cut + cut_reverb.source_feats = cut_clean.source_feats + cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets + cut_writer_reverb.write(cut_reverb) + pbar.update(1) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + add_source_feats() diff --git a/egs/ami/SURT/local/compute_fbank_aimix.py b/egs/ami/SURT/local/compute_fbank_aimix.py new file mode 100755 index 000000000..91b3a060b --- /dev/null +++ b/egs/ami/SURT/local/compute_fbank_aimix.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the synthetically mixed AMI and ICSI +train set. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. 
+""" +import logging +import random +import warnings +from pathlib import Path + +import torch +import torch.multiprocessing +import torchaudio +from lhotse import ( + AudioSource, + LilcomChunkyWriter, + Recording, + load_manifest, + load_manifest_lazy, +) +from lhotse.audio import set_ffmpeg_torchaudio_info_enabled +from lhotse.cut import MixedCut, MixTrack, MultiCut +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.utils import fix_random_seed, uuid4 +from tqdm import tqdm + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") +torchaudio.set_audio_backend("soundfile") +set_ffmpeg_torchaudio_info_enabled(False) + + +def compute_fbank_aimix(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz") + + # only uses RIRs and noises from REVERB challenge + real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter( + lambda r: "RVB2014" in r.id + ) + noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter( + lambda r: "RVB2014" in r.id + ) + + # Apply perturbation to the training cuts + logging.info("Applying perturbation to the training cuts") + train_cuts_rvb = train_cuts.map( + lambda c: augment( + c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True + ) + ) + + logging.info("Extracting fbank features for training cuts") + _ = train_cuts.compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / "ai-mix_feats_clean", + manifest_path=src_dir / "cuts_train_clean.jsonl.gz", + batch_duration=5000, + num_workers=4, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + _ = train_cuts_rvb.compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / "ai-mix_feats_reverb", + manifest_path=src_dir / "cuts_train_reverb.jsonl.gz", + batch_duration=5000, + num_workers=4, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + +def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False): + """ + Given a mixed cut, this function optionally applies the following augmentations: + - Perturbing the SNRs of the tracks (in range [-5, 5] dB) + - Reverberation using a randomly selected RIR + - Adding noise + - Perturbing the loudness (in range [-20, -25] dB) + """ + out_cut = cut.drop_features() + + # Perturb the SNRs (optional) + if perturb_snr: + snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))] + for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)): + if i == 0: + # Skip the first track since it is the reference + continue + track.snr = snr + + # Reverberate the cut (optional) + if rirs is not None: + # Select an RIR at random + rir = random.choice(rirs) + # Select a channel at random + rir_channel = 
random.choice(list(range(rir.num_channels))) + # Reverberate the cut + out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel]) + + # Add noise (optional) + if noises is not None: + # Select a noise recording at random + noise = random.choice(noises).to_cut() + if isinstance(noise, MultiCut): + noise = noise.to_mono()[0] + # Select an SNR at random + snr = random.uniform(10, 30) + # Repeat the noise to match the duration of the cut + noise = repeat_cut(noise, out_cut.duration) + out_cut = MixedCut( + id=out_cut.id, + tracks=[ + MixTrack(cut=out_cut, type="MixedCut"), + MixTrack(cut=noise, type="DataCut", snr=snr), + ], + ) + + # Perturb the loudness (optional) + if perturb_loudness: + target_loudness = random.uniform(-20, -25) + out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True) + return out_cut + + +def repeat_cut(cut, duration): + while cut.duration < duration: + cut = cut.mix(cut, offset_other_by=cut.duration) + return cut.truncate(duration=duration) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + fix_random_seed(42) + compute_fbank_aimix() diff --git a/egs/ami/SURT/local/compute_fbank_ami.py b/egs/ami/SURT/local/compute_fbank_ami.py new file mode 100755 index 000000000..351b41765 --- /dev/null +++ b/egs/ami/SURT/local/compute_fbank_ami.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the AMI dataset. +We compute features for full recordings (i.e., without trimming to supervisions). +This way we can create arbitrary segmentations later. + +The generated fbank features are saved in data/fbank. +""" +import logging +import math +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def compute_fbank_ami(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests = {} + for part in ["ihm-mix", "sdm", "mdm8-bf"]: + manifests[part] = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix=f"ami-{part}", + suffix="jsonl.gz", + ) + + for part in ["ihm-mix", "sdm", "mdm8-bf"]: + for split in ["train", "dev", "test"]: + logging.info(f"Processing {part} {split}") + cuts = CutSet.from_manifests( + **manifests[part][split] + ).compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"ami-{part}_{split}_feats", + manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz", + batch_duration=5000, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_ami() diff --git a/egs/ami/SURT/local/compute_fbank_icsi.py b/egs/ami/SURT/local/compute_fbank_icsi.py new file mode 100755 index 000000000..4e2ff3f3b --- /dev/null +++ b/egs/ami/SURT/local/compute_fbank_icsi.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the ICSI dataset. +We compute features for full recordings (i.e., without trimming to supervisions). +This way we can create arbitrary segmentations later. + +The generated fbank features are saved in data/fbank. +""" +import logging +import math +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def compute_fbank_icsi(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests = {} + for part in ["ihm-mix", "sdm"]: + manifests[part] = read_manifests_if_cached( + dataset_parts=["train"], + output_dir=src_dir, + prefix=f"icsi-{part}", + suffix="jsonl.gz", + ) + + for part in ["ihm-mix", "sdm"]: + for split in ["train"]: + logging.info(f"Processing {part} {split}") + cuts = CutSet.from_manifests( + **manifests[part][split] + ).compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"icsi-{part}_{split}_feats", + manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz", + batch_duration=5000, + num_workers=4, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_icsi() diff --git a/egs/ami/SURT/local/compute_fbank_ihm.py b/egs/ami/SURT/local/compute_fbank_ihm.py new file mode 100755 index 000000000..56f54aa21 --- /dev/null +++ b/egs/ami/SURT/local/compute_fbank_ihm.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the trimmed sub-segments which will be +used for simulating the training mixtures. + +The generated fbank features are saved in data/fbank. +""" +import logging +import math +from pathlib import Path + +import torch +import torch.multiprocessing +import torchaudio +from lhotse import CutSet, LilcomChunkyWriter, load_manifest +from lhotse.audio import set_ffmpeg_torchaudio_info_enabled +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached +from tqdm import tqdm + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") +torchaudio.set_audio_backend("soundfile") +set_ffmpeg_torchaudio_info_enabled(False) + + +def compute_fbank_ihm(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests = {} + for data in ["ami", "icsi"]: + manifests[data] = read_manifests_if_cached( + dataset_parts=["train"], + output_dir=src_dir, + types=["recordings", "supervisions"], + prefix=f"{data}-ihm", + suffix="jsonl.gz", + ) + + logging.info("Computing features") + for data in ["ami", "icsi"]: + cs = CutSet.from_manifests(**manifests[data]["train"]) + cs = cs.trim_to_supervisions(keep_overlapping=False) + cs = cs.normalize_loudness(target=-23.0, affix_id=False) + cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1) + _ = cs.compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"{data}-ihm_train_feats", + manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz", + batch_duration=5000, + num_workers=4, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_ihm() diff --git a/egs/ami/SURT/local/prepare_ami_train_cuts.py b/egs/ami/SURT/local/prepare_ami_train_cuts.py new file mode 100755 index 000000000..72fced70d --- /dev/null +++ b/egs/ami/SURT/local/prepare_ami_train_cuts.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file creates AMI train segments. +""" +import logging +import math +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import LilcomChunkyWriter, load_manifest_lazy +from lhotse.cut import Cut, CutSet +from lhotse.utils import EPSILON, add_durations +from tqdm import tqdm + + +def cut_into_windows(cuts: CutSet, duration: float): + """ + This function takes a CutSet and cuts each cut into windows of roughly + `duration` seconds. By roughly, we mean that we try to adjust for the last supervision + that exceeds the duration, or is shorter than the duration. 
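+
+    Illustrative usage (a sketch only; ``long_cuts`` is a placeholder name for
+    any CutSet with supervisions attached, it does not exist in this recipe):
+
+        windows = cut_into_windows(long_cuts, duration=30.0)
+
+    Windows may come out slightly longer or shorter than ``duration`` so that
+    supervisions are not split across two windows.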
+ """ + res = [] + with tqdm() as pbar: + for cut in cuts: + pbar.update(1) + sups = cut.index_supervisions()[cut.id] + sr = cut.sampling_rate + start = 0.0 + end = duration + num_tries = 0 + while start < cut.duration and num_tries < 2: + # Find the supervision that are cut by the window endpoint + hitlist = [iv for iv in sups.at(end) if iv.begin < end] + # If there are no supervisions, we are done + if not hitlist: + res.append( + cut.truncate( + offset=start, + duration=add_durations(end, -start, sampling_rate=sr), + keep_excessive_supervisions=False, + ) + ) + # Update the start and end for the next window + start = end + end = add_durations(end, duration, sampling_rate=sr) + else: + # find ratio of durations cut by the window endpoint + ratios = [ + add_durations(end, -iv.end, sampling_rate=sr) / iv.length() + for iv in hitlist + ] + # we retain the supervisions that have >50% of their duration + # in the window, and discard the others + retained = [] + discarded = [] + for iv, ratio in zip(hitlist, ratios): + if ratio > 0.5: + retained.append(iv) + else: + discarded.append(iv) + cur_end = max(iv.end for iv in retained) if retained else end + res.append( + cut.truncate( + offset=start, + duration=add_durations(cur_end, -start, sampling_rate=sr), + keep_excessive_supervisions=False, + ) + ) + # For the next window, we start at the earliest discarded supervision + next_start = min(iv.begin for iv in discarded) if discarded else end + next_end = add_durations(next_start, duration, sampling_rate=sr) + # It may happen that next_start is the same as start, in which case + # we will advance the window anyway + if next_start == start: + logging.warning( + f"Next start is the same as start: {next_start} == {start} for cut {cut.id}" + ) + start = end + EPSILON + end = add_durations(start, duration, sampling_rate=sr) + num_tries += 1 + else: + start = next_start + end = next_end + return CutSet.from_cuts(res) + + +def prepare_train_cuts(): + src_dir = Path("data/manifests") + + logging.info("Loading the manifests") + train_cuts_ihm = load_manifest_lazy( + src_dir / "cuts_ami-ihm-mix_train.jsonl.gz" + ).map(lambda c: c.with_id(f"{c.id}_ihm-mix")) + train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map( + lambda c: c.with_id(f"{c.id}_sdm") + ) + train_cuts_mdm = load_manifest_lazy( + src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz" + ).map(lambda c: c.with_id(f"{c.id}_mdm8-bf")) + + # Combine all cuts into one CutSet + train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm + + train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5) + train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0) + + # Combine the two segmentations + train_all = train_cuts_1 + train_cuts_2 + + # At this point, some of the cuts may be very long. We will cut them into windows of + # roughly 30 seconds. 
+ logging.info("Cutting the segments into windows of 30 seconds") + train_all_30 = cut_into_windows(train_all, duration=30.0) + logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}") + + # Show statistics + train_all.describe(full=True) + + # Save the cuts + logging.info("Saving the cuts") + train_all.to_file(src_dir / "cuts_train_ami.jsonl.gz") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_train_cuts() diff --git a/egs/ami/SURT/local/prepare_icsi_train_cuts.py b/egs/ami/SURT/local/prepare_icsi_train_cuts.py new file mode 100755 index 000000000..818e26bfb --- /dev/null +++ b/egs/ami/SURT/local/prepare_icsi_train_cuts.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file creates ICSI train segments. +""" +import logging +from pathlib import Path + +from lhotse import load_manifest_lazy +from prepare_ami_train_cuts import cut_into_windows + + +def prepare_train_cuts(): + src_dir = Path("data/manifests") + + logging.info("Loading the manifests") + train_cuts_ihm = load_manifest_lazy( + src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz" + ).map(lambda c: c.with_id(f"{c.id}_ihm-mix")) + train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map( + lambda c: c.with_id(f"{c.id}_sdm") + ) + + # Combine all cuts into one CutSet + train_cuts = train_cuts_ihm + train_cuts_sdm + + train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5) + train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0) + + # Combine the two segmentations + train_all = train_cuts_1 + train_cuts_2 + + # At this point, some of the cuts may be very long. We will cut them into windows of + # roughly 30 seconds. 
+ logging.info("Cutting the segments into windows of 30 seconds") + train_all_30 = cut_into_windows(train_all, duration=30.0) + logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}") + + # Show statistics + train_all.describe(full=True) + + # Save the cuts + logging.info("Saving the cuts") + train_all.to_file(src_dir / "cuts_train_icsi.jsonl.gz") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_train_cuts() diff --git a/egs/ami/SURT/local/prepare_lang_bpe.py b/egs/ami/SURT/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/ami/SURT/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/ami/SURT/local/train_bpe_model.py b/egs/ami/SURT/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/ami/SURT/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ami/SURT/prepare.sh b/egs/ami/SURT/prepare.sh new file mode 100755 index 000000000..ea4e5baf2 --- /dev/null +++ b/egs/ami/SURT/prepare.sh @@ -0,0 +1,195 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/ami +# You can find audio and transcripts for AMI in this path. +# +# - $dl_dir/icsi +# You can find audio and transcripts for ICSI in this path. +# +# - $dl_dir/rirs_noises +# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/. +# +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +vocab_size=500 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/amicorpus, + # you can create a symlink + # + # ln -sfv /path/to/amicorpus $dl_dir/amicorpus + # + if [ ! -d $dl_dir/amicorpus ]; then + for mic in ihm ihm-mix sdm mdm8-bf; do + lhotse download ami --mic $mic $dl_dir/amicorpus + done + fi + + # If you have pre-downloaded it to /path/to/icsi, + # you can create a symlink + # + # ln -sfv /path/to/icsi $dl_dir/icsi + # + if [ ! -d $dl_dir/icsi ]; then + lhotse download icsi $dl_dir/icsi + fi + + # If you have pre-downloaded it to /path/to/rirs_noises, + # you can create a symlink + # + # ln -sfv /path/to/rirs_noises $dl_dir/ + # + if [ ! -d $dl_dir/rirs_noises ]; then + lhotse download rirs_noises $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare AMI manifests" + # We assume that you have downloaded the AMI corpus + # to $dl_dir/amicorpus. We perform text normalization for the transcripts. 
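+  # The manifests produced here go to data/manifests with an "ami-<mic>" prefix;
+  # later stages (e.g. local/compute_fbank_ami.py) read them back from there.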
+  mkdir -p data/manifests
+  for mic in ihm ihm-mix sdm mdm8-bf; do
+    log "Preparing AMI manifest for $mic"
+    lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
+  done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare ICSI manifests"
+  # We assume that you have downloaded the ICSI corpus
+  # to $dl_dir/icsi. We perform text normalization for the transcripts.
+  mkdir -p data/manifests
+  log "Preparing ICSI manifest"
+  for mic in ihm ihm-mix sdm; do
+    lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
+  done
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare RIRs"
+  # We assume that you have downloaded the RIRS_NOISES corpus
+  # to $dl_dir/rirs_noises
+  lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Extract features for AMI and ICSI recordings"
+  python local/compute_fbank_ami.py
+  python local/compute_fbank_icsi.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Create sources for simulating mixtures"
+  # In the following script, we speed-perturb the IHM recordings and extract features.
+  python local/compute_fbank_ihm.py
+  lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
+    data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
+    lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
+    lhotse filter 'duration<=12.0' - - |\
+    shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Create training mixtures"
+  lhotse workflows simulate-meetings \
+    --method conversational \
+    --same-spk-pause 0.5 \
+    --diff-spk-pause 0.5 \
+    --diff-spk-overlap 1.0 \
+    --prob-diff-spk-overlap 0.8 \
+    --num-meetings 200000 \
+    --num-speakers-per-meeting 2,3 \
+    --max-duration-per-speaker 15.0 \
+    --max-utterances-per-speaker 3 \
+    --seed 1234 \
+    --num-jobs 2 \
+    data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
+    data/manifests/ai-mix_cuts_clean.jsonl.gz
+
+  python local/compute_fbank_aimix.py
+
+  # Add source features to the manifest (will be used for masking loss)
+  # This may take ~2 hours.
+  python local/add_source_feats.py
+
+  # Combine clean and reverb
+  cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
+    <(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
+    shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Create training mixtures from real sessions"
+  python local/prepare_ami_train_cuts.py
+  python local/prepare_icsi_train_cuts.py
+
+  # Combine AMI and ICSI
+  cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
+    <(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
+    shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
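+  # Each supervision is a JSON object; below, jq extracts its "text" field and
+  # sed strips the surrounding quotes, leaving one transcript per line.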
+ mkdir -p data/lm + cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \ + <(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \ + > data/lm/transcript_words.txt +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)" + + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + # Add special words to words.txt + echo " 0" > $lang_dir/words.txt + echo "!SIL 1" >> $lang_dir/words.txt + echo " 2" >> $lang_dir/words.txt + + # Add regular words to words.txt + cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt + + # Add remaining special word symbols expected by LM scripts. + num_words=$(cat $lang_dir/words.txt | wc -l) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(cat $lang_dir/words.txt | wc -l) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(cat $lang_dir/words.txt | wc -l) + echo "#0 ${num_words}" >> $lang_dir/words.txt + + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript data/lm/transcript_words.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi +fi diff --git a/egs/ami/SURT/shared b/egs/ami/SURT/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/ami/SURT/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file From 5ed6fc0e6d9afeebaf86ec83c16d9ff2c8d6a0ba Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 12 Jul 2023 15:37:14 +0800 Subject: [PATCH 004/113] add sym link (#1170) --- egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py | 1 + egs/wenetspeech/ASR/local/sort_lm_training_data.py | 1 + 2 files changed, 2 insertions(+) create mode 120000 egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py create mode 120000 egs/wenetspeech/ASR/local/sort_lm_training_data.py diff --git a/egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py b/egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py new file mode 120000 index 000000000..2374cafdd --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_char_lm_training_data.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char_lm_training_data.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/local/sort_lm_training_data.py b/egs/wenetspeech/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..efef2c445 --- /dev/null +++ b/egs/wenetspeech/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/sort_lm_training_data.py \ No newline at end of file From 4ab7d610081c0c3b38dd851298cb45381e6ac591 Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Sat, 15 Jul 2023 12:39:32 +0800 Subject: [PATCH 005/113] removed `batch_name` to fix a KeyError with "uttid" (#1172) --- egs/librispeech/ASR/conformer_ctc2/train.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 3366af13e..c4a13b101 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -675,7 +675,6 @@ def train_one_epoch( for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - batch_name = batch["supervisions"]["uttid"] 
with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( @@ -698,10 +697,7 @@ def train_one_epoch( scaler.scale(loss).backward() except RuntimeError as e: if "CUDA out of memory" in str(e): - logging.error( - f"failing batch size:{batch_size} " - f"failing batch names {batch_name}" - ) + logging.error(f"failing batch size:{batch_size} ") raise scheduler.step_batch(params.batch_idx_train) @@ -756,10 +752,7 @@ def train_one_epoch( if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( "inf" ): - logging.error( - "Your loss contains inf, something goes wrong" - f"failing batch names {batch_name}" - ) + logging.error("Your loss contains inf, something goes wrong") if tb_writer is not None: tb_writer.add_scalar( "train/learning_rate", cur_lr, params.batch_idx_train From 1dbbd7759ef707eca36bb899bcea8e32afc52282 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 25 Jul 2023 14:46:18 +0800 Subject: [PATCH 006/113] Add tests for subsample.py and fix typos (#1180) --- .github/workflows/test.yml | 57 ++----- .../pruned_transducer_stateless2/conformer.py | 2 + .../pruned_transducer_stateless3/test_onnx.py | 6 +- .../pruned_transducer_stateless7/test_onnx.py | 3 +- egs/librispeech/ASR/zipformer/.gitignore | 1 + egs/librispeech/ASR/zipformer/model.py | 2 +- egs/librispeech/ASR/zipformer/scaling.py | 14 +- egs/librispeech/ASR/zipformer/subsampling.py | 23 +-- egs/librispeech/ASR/zipformer/test_scaling.py | 82 ++++++++++ .../ASR/zipformer/test_subsampling.py | 152 ++++++++++++++++++ egs/librispeech/ASR/zipformer/zipformer.py | 4 +- 11 files changed, 276 insertions(+), 70 deletions(-) create mode 100644 egs/librispeech/ASR/zipformer/.gitignore create mode 100755 egs/librispeech/ASR/zipformer/test_scaling.py create mode 100755 egs/librispeech/ASR/zipformer/test_subsampling.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e04fb5655..363556bb7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,9 +35,9 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.8"] - torch: ["1.10.0"] - torchaudio: ["0.10.0"] - k2-version: ["1.23.2.dev20221201"] + torch: ["1.13.0"] + torchaudio: ["0.13.0"] + k2-version: ["1.24.3.dev20230719"] fail-fast: false @@ -66,14 +66,14 @@ jobs: pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html pip install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip uninstall -y protobuf pip install --no-binary protobuf protobuf==3.20.* pip install kaldifst - pip install onnxruntime + pip install onnxruntime matplotlib pip install -r requirements.txt - name: Install graphviz @@ -83,13 +83,6 @@ jobs: python3 -m pip install -qq graphviz sudo apt-get -qq install graphviz - - name: Install graphviz - if: startsWith(matrix.os, 'macos') - shell: bash - run: | - python3 -m pip install -qq graphviz - brew install -q graphviz - - name: Run tests if: startsWith(matrix.os, 'ubuntu') run: | @@ -129,40 +122,10 @@ jobs: cd ../transducer_lstm pytest -v -s - - name: Run tests - if: startsWith(matrix.os, 'macos') - run: | - ls -lh - export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH - lib_path=$(python -c "from 
distutils.sysconfig import get_python_lib; print(get_python_lib())") - echo "lib_path: $lib_path" - export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH - pytest -v -s ./test - - # run tests for conformer ctc - cd egs/librispeech/ASR/conformer_ctc + cd ../zipformer pytest -v -s - cd ../pruned_transducer_stateless - pytest -v -s - - cd ../pruned_transducer_stateless2 - pytest -v -s - - cd ../pruned_transducer_stateless3 - pytest -v -s - - cd ../pruned_transducer_stateless4 - pytest -v -s - - cd ../transducer_stateless - pytest -v -s - - # cd ../transducer - # pytest -v -s - - cd ../transducer_stateless2 - pytest -v -s - - cd ../transducer_lstm - pytest -v -s + - uses: actions/upload-artifact@v2 + with: + path: egs/librispeech/ASR/zipformer/swoosh.pdf + name: swoosh.pdf diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 9bac46004..bcd419fb7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -849,6 +849,8 @@ class RelPositionalEncoding(torch.nn.Module): torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ + if isinstance(left_context, torch.Tensor): + left_context = left_context.item() self.extend_pe(x, left_context) x_size_1 = x.size(1) + left_context pos_emb = self.pe[ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 598fcf344..810da8da6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -113,7 +113,7 @@ def test_rel_pos(): torch.onnx.export( encoder_pos, - x, + (x, torch.zeros(1, dtype=torch.int64)), filename, verbose=False, opset_version=opset_version, @@ -139,7 +139,9 @@ def test_rel_pos(): assert input_nodes[0].name == "x" assert input_nodes[0].shape == ["N", "T", num_features] - inputs = {input_nodes[0].name: x.numpy()} + inputs = { + input_nodes[0].name: x.numpy(), + } onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs) onnx_y = torch.from_numpy(onnx_y) onnx_pos_emb = torch.from_numpy(onnx_pos_emb) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py index 2440d267c..1e9b67226 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -265,7 +265,7 @@ def test_zipformer_encoder(): torch.onnx.export( encoder, - (x), + (x, torch.ones(1, dtype=torch.float32)), filename, verbose=False, opset_version=opset_version, @@ -289,6 +289,7 @@ def test_zipformer_encoder(): input_nodes = session.get_inputs() inputs = { input_nodes[0].name: x.numpy(), + input_nodes[1].name: torch.ones(1, dtype=torch.float32).numpy(), } onnx_y = session.run(["y"], inputs)[0] onnx_y = torch.from_numpy(onnx_y) diff --git a/egs/librispeech/ASR/zipformer/.gitignore b/egs/librispeech/ASR/zipformer/.gitignore new file mode 100644 index 000000000..e47ac1582 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index b541ee697..f2f86af47 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -320,7 +320,7 @@ class AsrModel(nn.Module): assert x_lens.ndim == 1, x_lens.shape assert y.num_axes == 2, 
y.num_axes - assert x.size(0) == x_lens.size(0) == y.dim0 + assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4ee7b7826..7c98ef045 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -125,7 +125,7 @@ class PiecewiseLinear(object): p: 'PiecewiseLinear', include_crossings: bool = False): """ - Returns (self_mod, p_mod) which are equivalent piecewise lienar + Returns (self_mod, p_mod) which are equivalent piecewise linear functions to self and p, but with the same x values. p: the other piecewise linear function @@ -166,7 +166,7 @@ class ScheduledFloat(torch.nn.Module): in, float(parent_module.whatever), and use it as something like a dropout prob. It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specifiy the (x,y) pairs + training loop. It is a piecewise linear function where you specify the (x,y) pairs in sorted order on x; x corresponds to the batch index. For batch-index values before the first x or after the last x, we just use the first or last y value. @@ -343,7 +343,7 @@ class MaxEigLimiterFunction(torch.autograd.Function): class BiasNormFunction(torch.autograd.Function): # This computes: # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return (x - bias) * scales + # return x * scales # (after unsqueezing the bias), but it does it in a memory-efficient way so that # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). @@ -400,8 +400,8 @@ class BiasNorm(torch.nn.Module): Args: num_channels: the number of channels, e.g. 512. channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of + interpreted as an offset from the input's ndim if negative. + This is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. log_scale: the initial log-scale that we multiply the output by; this is learnable. @@ -1286,7 +1286,7 @@ class Dropout3(nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1361,7 +1361,7 @@ class SwooshLOnnx(torch.nn.Module): class SwooshRFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 derivatives are between -0.08 and 0.92. 
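+
+    (Why: d/dx log(1 + exp(x-1)) = sigmoid(x-1), which lies in (0, 1), so after
+    subtracting 0.08 the slope stays within (-0.08, 0.92).)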
""" diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index d6bf57db4..6532ddccb 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -138,9 +138,11 @@ class ConvNeXt(nn.Module): x = bypass + x x = self.out_balancer(x) - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) + + if x.requires_grad: + x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last + x = self.out_whiten(x) + x = x.transpose(1, 3) # (N, C, H, W) return x @@ -266,6 +268,7 @@ class Conv2dSubsampling(nn.Module): # just one convnext layer self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + # (in_channels-3)//4 self.out_width = (((in_channels - 1) // 2) - 1) // 2 self.layer3_channels = layer3_channels @@ -299,7 +302,7 @@ class Conv2dSubsampling(nn.Module): A tensor of shape (batch_size,) containing the number of frames in Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + - a tensor of shape (N, (T-7)//2, odim) - output lengths, of shape (batch_size,) """ # On entry, x is (N, T, idim) @@ -310,14 +313,14 @@ class Conv2dSubsampling(nn.Module): x = self.conv(x) x = self.convnext(x) - # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) + # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) b, c, t, f = x.size() x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) + # now x: (N, (T-7)//2, out_width * layer3_channels)) x = self.out(x) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + # Now x is of shape (N, (T-7)//2, odim) x = self.out_whiten(x) x = self.out_norm(x) x = self.dropout(x) @@ -328,7 +331,7 @@ class Conv2dSubsampling(nn.Module): with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() + assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) return x, x_lens @@ -347,7 +350,7 @@ class Conv2dSubsampling(nn.Module): A tensor of shape (batch_size,) containing the number of frames in Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + - a tensor of shape (N, (T-7)//2, odim) - output lengths, of shape (batch_size,) - updated cache """ @@ -383,7 +386,7 @@ class Conv2dSubsampling(nn.Module): assert self.convnext.padding[0] == 3 x_lens = (x_lens - 7) // 2 - 3 - assert x.size(1) == x_lens.max().item() + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) return x, x_lens, cached_left_pad diff --git a/egs/librispeech/ASR/zipformer/test_scaling.py b/egs/librispeech/ASR/zipformer/test_scaling.py new file mode 100755 index 000000000..5c04291e7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/test_scaling.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +import matplotlib.pyplot as plt +import torch +from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR + + +def test_piecewise_linear(): + # An identity map in the range [0, 1]. 
+ # 1 - identity map in the range [1, 2] + # x1=0, y1=0 + # x2=1, y2=1 + # x3=2, y3=0 + pl = PiecewiseLinear((0, 0), (1, 1), (2, 0)) + assert pl(0.25) == 0.25, pl(0.25) + assert pl(0.625) == 0.625, pl(0.625) + assert pl(1.25) == 0.75, pl(1.25) + + assert pl(-10) == pl(0), pl(-10) # out of range + assert pl(10) == pl(2), pl(10) # out of range + + # multiplication + pl10 = pl * 10 + assert pl10(1) == 10 * pl(1) + assert pl10(0.5) == 10 * pl(0.5) + + +def test_scheduled_float(): + # Initial value is 0.2 and it decreases linearly towards 0 at 4000 + dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0) + dropout.batch_count = 0 + assert float(dropout) == 0.2, (float(dropout), dropout.batch_count) + + dropout.batch_count = 1000 + assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count) + + dropout.batch_count = 2000 + assert float(dropout) == 0.1, (float(dropout), dropout.batch_count) + + dropout.batch_count = 3000 + assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count) + + dropout.batch_count = 4000 + assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) + + dropout.batch_count = 5000 # out of range + assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) + + +def test_swoosh(): + x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32) + x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32) + x = torch.cat([x1, x2[1:]]) + + left = SwooshL()(x) + r = SwooshR()(x) + + relu = torch.nn.functional.relu(x) + print(left[x == 0], r[x == 0]) + plt.plot(x, left, "k") + plt.plot(x, r, "r") + plt.plot(x, relu, "b") + plt.axis([-10, 10, -1, 10]) # [xmin, xmax, ymin, ymax] + plt.legend( + [ + "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ", + "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687", + "ReLU(x) = max(0, x)", + ] + ) + plt.grid() + plt.savefig("swoosh.pdf") + + +def main(): + test_piecewise_linear() + test_scheduled_float() + test_swoosh() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/test_subsampling.py b/egs/librispeech/ASR/zipformer/test_subsampling.py new file mode 100755 index 000000000..078227fb6 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/test_subsampling.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 + +import torch +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling + + +def test_conv2d_subsampling(): + layer1_channels = 8 + layer2_channels = 32 + layer3_channels = 128 + + out_channels = 192 + encoder_embed = Conv2dSubsampling( + in_channels=80, + out_channels=out_channels, + layer1_channels=layer1_channels, + layer2_channels=layer2_channels, + layer3_channels=layer3_channels, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + N = 2 + T = 200 + num_features = 80 + x = torch.rand(N, T, num_features) + x_copy = x.clone() + + x = x.unsqueeze(1) # (N, 1, T, num_features) + + x = encoder_embed.conv[0](x) # conv2d, in 1, out 8, kernel 3, padding (0,1) + assert x.shape == (N, layer1_channels, T - 2, num_features) + # (2, 8, 198, 80) + + x = encoder_embed.conv[1](x) # scale grad + x = encoder_embed.conv[2](x) # balancer + x = encoder_embed.conv[3](x) # swooshR + + x = encoder_embed.conv[4](x) # conv2d, in 8, out 32, kernel 3, stride 2 + assert x.shape == ( + N, + layer2_channels, + ((T - 2) - 3) // 2 + 1, + (num_features - 3) // 2 + 1, + ) + # (2, 32, 98, 39) + + x = encoder_embed.conv[5](x) # balancer + x = encoder_embed.conv[6](x) # swooshR + + # conv2d: + # in 32, out 128, kernel 3, stride (1, 2) + x = 
encoder_embed.conv[7](x) + assert x.shape == ( + N, + layer3_channels, + (((T - 2) - 3) // 2 + 1) - 2, + (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + ) + # (2, 128, 96, 19) + + x = encoder_embed.conv[8](x) # balancer + x = encoder_embed.conv[9](x) # swooshR + + # (((T - 2) - 3) // 2 + 1) - 2 + # = (T - 2) - 3) // 2 + 1 - 2 + # = ((T - 2) - 3) // 2 - 1 + # = (T - 2 - 3) // 2 - 1 + # = (T - 5) // 2 - 1 + # = (T - 7) // 2 + assert x.shape[2] == (x_copy.shape[1] - 7) // 2 + + # (((num_features - 3) // 2 + 1) - 3) // 2 + 1, + # = ((num_features - 3) // 2 + 1 - 3) // 2 + 1, + # = ((num_features - 3) // 2 - 2) // 2 + 1, + # = (num_features - 3 - 4) // 2 // 2 + 1, + # = (num_features - 7) // 2 // 2 + 1, + # = (num_features - 7) // 4 + 1, + # = (num_features - 3) // 4 + assert x.shape[3] == (x_copy.shape[2] - 3) // 4 + + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # Input shape to convnext is + # + # (N, layer3_channels, (T-7)//2, (num_features - 3)//4) + + # conv2d: in layer3_channels, out layer3_channels, groups layer3_channels + # kernel_size 7, padding 3 + x = encoder_embed.convnext.depthwise_conv(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # conv2d: in layer3_channels, out hidden_ratio * layer3_channels, kernel_size 1 + x = encoder_embed.convnext.pointwise_conv1(x) + assert x.shape == (N, layer3_channels * 3, (T - 7) // 2, (num_features - 3) // 4) + + x = encoder_embed.convnext.hidden_balancer(x) # balancer + x = encoder_embed.convnext.activation(x) # swooshL + + # conv2d: in hidden_ratio * layer3_channels, out layer3_channels, kernel 1 + x = encoder_embed.convnext.pointwise_conv2(x) + assert x.shape == (N, layer3_channels, (T - 7) // 2, (num_features - 3) // 4) + + # bypass and layer drop, omitted here. 
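+    # (ConvNeXt.forward adds the bypass back via ``x = bypass + x`` before
+    # out_balancer; skipping that here is fine since it does not change shapes.)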
+ x = encoder_embed.convnext.out_balancer(x) + + # Note: the input and output shape of ConvNeXt are the same + + x = x.transpose(1, 2).reshape(N, (T - 7) // 2, -1) + assert x.shape == (N, (T - 7) // 2, layer3_channels * ((num_features - 3) // 4)) + + x = encoder_embed.out(x) + assert x.shape == (N, (T - 7) // 2, out_channels) + + x = encoder_embed.out_whiten(x) + x = encoder_embed.out_norm(x) + # final layer is dropout + + # test streaming forward + + subsampling_factor = 2 + cached_left_padding = encoder_embed.get_init_states(batch_size=N) + depthwise_conv_kernel_size = 7 + pad_size = (depthwise_conv_kernel_size - 1) // 2 + + assert cached_left_padding.shape == ( + N, + layer3_channels, + pad_size, + (num_features - 3) // 4, + ) + + chunk_size = 16 + right_padding = pad_size * subsampling_factor + T = chunk_size * subsampling_factor + 7 + right_padding + x = torch.rand(N, T, num_features) + x_lens = torch.tensor([T] * N) + y, y_lens, next_cached_left_padding = encoder_embed.streaming_forward( + x, x_lens, cached_left_padding + ) + + assert y.shape == (N, chunk_size, out_channels), y.shape + assert next_cached_left_padding.shape == cached_left_padding.shape + + assert y.shape[1] == y_lens[0] == y_lens[1] + + +def main(): + test_conv2d_subsampling() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7d98dbeb1..b39af02b8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -219,7 +219,7 @@ class Zipformer2(EncoderInterface): (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dim[0] == _encoder_dims0 + assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0) feature_mask_dropout_prob = 0.125 @@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface): x = self._get_full_dim_output(outputs) x = self.downsample_output(x) # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 + assert self.output_downsampling_factor == 2, self.output_downsampling_factor if torch.jit.is_scripting() or torch.jit.is_tracing(): lengths = (x_lens + 1) // 2 else: From 80d922c1583b9b7fb7e9b47008302cdc74ef58b7 Mon Sep 17 00:00:00 2001 From: kobenaxie <572745565@qq.com> Date: Wed, 26 Jul 2023 16:54:42 +0800 Subject: [PATCH 007/113] Update preprocess_commonvoice.py to fix text normalization bug. 
(#1181) --- egs/commonvoice/ASR/local/preprocess_commonvoice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index c5ec14502..e60459765 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -45,7 +45,7 @@ def get_args(): def normalize_text(utt: str) -> str: utt = re.sub(r"[{0}]+".format("-"), " ", utt) - return re.sub(r"[^a-zA-Z\s]", "", utt).upper() + return re.sub(r"[^a-zA-Z\s']", "", utt).upper() def preprocess_commonvoice( From 625b33e9ad15961239ea77d12472428d8006085d Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 27 Jul 2023 12:08:20 +0800 Subject: [PATCH 008/113] Update descriptions for different decoding methods with external LMs (#1185) * add some descriptions * minor updates --- .../decoding-with-langugage-models/index.rst | 21 +++++++++++++++++++ .../rescoring.rst | 14 ++++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/index.rst b/docs/source/decoding-with-langugage-models/index.rst index 577ebbdfb..6e5e3a4d9 100644 --- a/docs/source/decoding-with-langugage-models/index.rst +++ b/docs/source/decoding-with-langugage-models/index.rst @@ -4,6 +4,27 @@ Decoding with language models This section describes how to use external langugage models during decoding to improve the WER of transducer models. +The following decoding methods with external langugage models are available: + + +.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean) + :widths: 25 50 + :header-rows: 1 + + * - Decoding method + - beam=4 + * - ``modified_beam_search`` + - Beam search (i.e. really n-best decoding, the "beam" is the value of n), similar to the original RNN-T paper. Note, this method does not use language model. + * - ``modified_beam_search_lm_shallow_fusion`` + - As ``modified_beam_search``, but interpolate RNN-T scores with language model scores, also known as shallow fusion + * - ``modified_beam_search_LODR`` + - As ``modified_beam_search_lm_shallow_fusion``, but subtract score of a (BPE-symbol-level) bigram backoff language model used as an approximation to the internal language model of RNN-T. + * - ``modified_beam_search_lm_rescore`` + - As ``modified_beam_search``, but rescore the n-best hypotheses with external language model (e.g. RNNLM) and re-rank them. + * - ``modified_beam_search_lm_rescore_LODR`` + - As ``modified_beam_search_lm_rescore``, but also subtract the score of a (BPE-symbol-level) bigram backoff language model during re-ranking. + + .. toctree:: :maxdepth: 2 diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst index d71acc1e5..de7e700d0 100644 --- a/docs/source/decoding-with-langugage-models/rescoring.rst +++ b/docs/source/decoding-with-langugage-models/rescoring.rst @@ -4,7 +4,11 @@ LM rescoring for Transducer ================================= LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based +<<<<<<< HEAD +methods (see :ref:`shallow_fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search. 
+======= methods (see :ref:`shallow-fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search. +>>>>>>> 80d922c1583b9b7fb7e9b47008302cdc74ef58b7 Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM. In this tutorial, we will show you how to use external LM to rescore the n-best hypotheses decoded from neural transducer models in `icefall `__. @@ -225,23 +229,23 @@ Here, we benchmark the WERs and decoding speed of them: - beam=4 - beam=8 - beam=12 - * - `modified_beam_search` + * - ``modified_beam_search`` - 3.11/7.93; 132s - 3.1/7.95; 177s - 3.1/7.96; 210s - * - `modified_beam_search_lm_shallow_fusion` + * - ``modified_beam_search_lm_shallow_fusion`` - 2.77/7.08; 262s - 2.62/6.65; 352s - 2.58/6.65; 488s - * - LODR + * - ``modified_beam_search_LODR`` - 2.61/6.74; 400s - 2.45/6.38; 610s - 2.4/6.23; 870s - * - `modified_beam_search_lm_rescore` + * - ``modified_beam_search_lm_rescore`` - 2.93/7.6; 156s - 2.67/7.11; 203s - 2.59/6.86; 255s - * - `modified_beam_search_lm_rescore_LODR` + * - ``modified_beam_search_lm_rescore_LODR`` - 2.9/7.57; 160s - 2.63/7.04; 203s - 2.52/6.73; 263s From 3fb0a431704a18c9d04230b07a1d75b7ea159970 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 27 Jul 2023 12:36:05 +0800 Subject: [PATCH 009/113] Fix conflict (#1187) Resolve conflict --- docs/source/decoding-with-langugage-models/rescoring.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst index de7e700d0..ee2e2113c 100644 --- a/docs/source/decoding-with-langugage-models/rescoring.rst +++ b/docs/source/decoding-with-langugage-models/rescoring.rst @@ -4,11 +4,7 @@ LM rescoring for Transducer ================================= LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based -<<<<<<< HEAD methods (see :ref:`shallow_fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search. -======= -methods (see :ref:`shallow-fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search. ->>>>>>> 80d922c1583b9b7fb7e9b47008302cdc74ef58b7 Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM. In this tutorial, we will show you how to use external LM to rescore the n-best hypotheses decoded from neural transducer models in `icefall `__. From 19b942c958cba13a78757c9f7a287f8c88460bd0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 27 Jul 2023 13:36:46 +0800 Subject: [PATCH 010/113] Update installation doc. (#1188) --- docs/source/conf.py | 5 + docs/source/installation/index.rst | 687 +++++++++++++++-------------- 2 files changed, 354 insertions(+), 338 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 0ff3f801c..bf231e3c1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -90,4 +90,9 @@ rst_epilog = """ .. _musan: http://www.openslr.org/17/ .. _ONNX: https://github.com/onnx/onnx .. _onnxruntime: https://github.com/microsoft/onnxruntime +.. _torch: https://github.com/pytorch/pytorch +.. _torchaudio: https://github.com/pytorch/audio +.. _k2: https://github.com/k2-fsa/k2 +.. _lhotse: https://github.com/lhotse-speech/lhotse +.. 
_yesno: https://www.openslr.org/1/ """ diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 738b24ab2..534b674f9 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -3,40 +3,23 @@ Installation ============ +.. hint:: + We have a colab notebook guiding you step by step to setup the environment. -``icefall`` depends on `k2 `_ and -`lhotse `_. + |yesno colab notebook| + + .. |yesno colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing + +`icefall`_ depends on `k2`_ and `lhotse`_. We recommend that you use the following steps to install the dependencies. - (0) Install CUDA toolkit and cuDNN -- (1) Install PyTorch and torchaudio -- (2) Install k2 -- (3) Install lhotse - -.. caution:: - - 99% users who have issues about the installation are using conda. - -.. caution:: - - 99% users who have issues about the installation are using conda. - -.. caution:: - - 99% users who have issues about the installation are using conda. - -.. hint:: - - We suggest that you use ``pip install`` to install PyTorch. - - You can use the following command to create a virutal environment in Python: - - .. code-block:: bash - - python3 -m venv ./my_env - source ./my_env/bin/activate +- (1) Install `torch`_ and `torchaudio`_ +- (2) Install `k2`_ +- (3) Install `lhotse`_ .. caution:: @@ -50,27 +33,20 @@ Please refer to to install CUDA and cuDNN. -(1) Install PyTorch and torchaudio ----------------------------------- +(1) Install torch and torchaudio +-------------------------------- -Please refer ``_ to install PyTorch -and torchaudio. - -.. hint:: - - You can also go to ``_ - to download pre-compiled wheels and install them. +Please refer ``_ to install `torch`_ and `torchaudio`_. .. caution:: Please install torch and torchaudio at the same time. - (2) Install k2 -------------- Please refer to ``_ -to install ``k2``. +to install `k2`_. .. caution:: @@ -78,21 +54,18 @@ to install ``k2``. .. note:: - We suggest that you install k2 from source by following - ``_ - or - ``_. + We suggest that you install k2 from pre-compiled wheels by following + ``_ .. hint:: - Please always install the latest version of k2. + Please always install the latest version of `k2`_. (3) Install lhotse ------------------ Please refer to ``_ -to install ``lhotse``. - +to install `lhotse`_. .. hint:: @@ -100,17 +73,16 @@ to install ``lhotse``. pip install git+https://github.com/lhotse-speech/lhotse - to install the latest version of lhotse. + to install the latest version of `lhotse`_. (4) Download icefall -------------------- -``icefall`` is a collection of Python scripts; what you need is to download it +`icefall`_ is a collection of Python scripts; what you need is to download it and set the environment variable ``PYTHONPATH`` to point to it. -Assume you want to place ``icefall`` in the folder ``/tmp``. The -following commands show you how to setup ``icefall``: - +Assume you want to place `icefall`_ in the folder ``/tmp``. The +following commands show you how to setup `icefall`_: .. code-block:: bash @@ -122,285 +94,334 @@ following commands show you how to setup ``icefall``: .. HINT:: - You can put several versions of ``icefall`` in the same virtual environment. - To switch among different versions of ``icefall``, just set ``PYTHONPATH`` + You can put several versions of `icefall`_ in the same virtual environment. 
+ To switch among different versions of `icefall`_, just set ``PYTHONPATH`` to point to the version you want. - Installation example -------------------- The following shows an example about setting up the environment. - (1) Create a virtual environment ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash - $ virtualenv -p python3.8 test-icefall + kuangfangjun:~$ virtualenv -p python3.8 test-icefall + created virtual environment CPython3.8.0.final.0-64 in 9422ms + creator CPython3Posix(dest=/star-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False) + seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/star-fj/fangjun/.local/share/virtualenv) + added seed packages: pip==22.3.1, setuptools==65.6.3, wheel==0.38.4 + activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator - created virtual environment CPython3.8.6.final.0-64 in 1540ms - creator CPython3Posix(dest=/ceph-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False) - seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/fangjun/.local/share/v - irtualenv) - added seed packages: pip==21.1.3, setuptools==57.4.0, wheel==0.36.2 - activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator + kuangfangjun:~$ source test-icefall/bin/activate + (test-icefall) kuangfangjun:~$ -(2) Activate your virtual environment -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +(2) Install CUDA toolkit and cuDNN +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You need to determine the version of CUDA toolkit to install. .. code-block:: bash - $ source test-icefall/bin/activate + (test-icefall) kuangfangjun:~$ nvidia-smi | head -n 4 -(3) Install k2 + Wed Jul 26 21:57:49 2023 + +-----------------------------------------------------------------------------+ + | NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 | + |-------------------------------+----------------------+----------------------+ + +You can choose any CUDA version that is ``not`` greater than the version printed by ``nvidia-smi``. +In our case, we can choose any version ``<= 11.6``. + +We will use ``CUDA 11.6`` in this example. Please follow +``_ +to install CUDA toolkit and cuDNN if you have not done that before. + +After installing CUDA toolkit, you can use the following command to verify it: + +.. code-block:: bash + + (test-icefall) kuangfangjun:~$ nvcc --version + + nvcc: NVIDIA (R) Cuda compiler driver + Copyright (c) 2005-2019 NVIDIA Corporation + Built on Wed_Oct_23_19:24:38_PDT_2019 + Cuda compilation tools, release 10.2, V10.2.89 + +(3) Install torch and torchaudio +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Since we have selected CUDA toolkit ``11.6``, we have to install a version of `torch`_ +that is compiled against CUDA ``11.6``. We select ``torch 1.13.0+cu116`` in this +example. + +After selecting the version of `torch`_ to install, we need to also install +a compatible version of `torchaudio`_, which is ``0.13.0+cu116`` in our case. + +Please refer to ``_ +to select an appropriate version of `torchaudio`_ to install if you use a different +version of `torch`_. + +.. 
code-block:: bash + + (test-icefall) kuangfangjun:~$ pip install torch==1.13.0+cu116 torchaudio==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html + + Looking in links: https://download.pytorch.org/whl/torch_stable.html + Collecting torch==1.13.0+cu116 + Downloading https://download.pytorch.org/whl/cu116/torch-1.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl (1983.0 MB) + ________________________________________ 2.0/2.0 GB 764.4 kB/s eta 0:00:00 + Collecting torchaudio==0.13.0+cu116 + Downloading https://download.pytorch.org/whl/cu116/torchaudio-0.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl (4.2 MB) + ________________________________________ 4.2/4.2 MB 1.3 MB/s eta 0:00:00 + Requirement already satisfied: typing-extensions in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from torch==1.13.0+cu116) (4.7.1) + Installing collected packages: torch, torchaudio + Successfully installed torch-1.13.0+cu116 torchaudio-0.13.0+cu116 + +Verify that `torch`_ and `torchaudio`_ are successfully installed: + +.. code-block:: bash + + (test-icefall) kuangfangjun:~$ python3 -c "import torch; print(torch.__version__)" + + 1.13.0+cu116 + + (test-icefall) kuangfangjun:~$ python3 -c "import torchaudio; print(torchaudio.__version__)" + + 0.13.0+cu116 + +(4) Install k2 ~~~~~~~~~~~~~~ +We will install `k2`_ from pre-compiled wheels by following +``_ + .. code-block:: bash - $ pip install k2==1.4.dev20210822+cpu.torch1.9.0 -f https://k2-fsa.org/nightly/index.html + (test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html - Looking in links: https://k2-fsa.org/nightly/index.html - Collecting k2==1.4.dev20210822+cpu.torch1.9.0 - Downloading https://k2-fsa.org/nightly/whl/k2-1.4.dev20210822%2Bcpu.torch1.9.0-cp38-cp38-linux_x86_64.whl (1.6 MB) - |________________________________| 1.6 MB 185 kB/s + Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple + Looking in links: https://k2-fsa.github.io/k2/cuda.html + Collecting k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 + Downloading https://huggingface.co/csukuangfj/k2/resolve/main/ubuntu-cuda/k2-1.24.3.dev20230725%2Bcuda11.6.torch1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (104.3 MB) + ________________________________________ 104.3/104.3 MB 5.1 MB/s eta 0:00:00 + Requirement already satisfied: torch==1.13.0 in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from k2==1.24.3.dev20230725+cuda11.6.torch1.13.0) (1.13.0+cu116) Collecting graphviz - Downloading graphviz-0.17-py3-none-any.whl (18 kB) - Collecting torch==1.9.0 - Using cached torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl (831.4 MB) - Collecting typing-extensions - Using cached typing_extensions-3.10.0.0-py3-none-any.whl (26 kB) - Installing collected packages: typing-extensions, torch, graphviz, k2 - Successfully installed graphviz-0.17 k2-1.4.dev20210822+cpu.torch1.9.0 torch-1.9.0 typing-extensions-3.10.0.0 + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/de/5e/fcbb22c68208d39edff467809d06c9d81d7d27426460ebc598e55130c1aa/graphviz-0.20.1-py3-none-any.whl (47 kB) + Requirement already satisfied: typing-extensions in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from torch==1.13.0->k2==1.24.3.dev20230725+cuda11.6.torch1.13.0) (4.7.1) + Installing collected packages: graphviz, k2 + Successfully installed graphviz-0.20.1 k2-1.24.3.dev20230725+cuda11.6.torch1.13.0 -.. WARNING:: +.. hint:: - We choose to install a CPU version of k2 for testing. 
You would probably want to install - a CUDA version of k2. + Please refer to ``_ for the available + pre-compiled wheels about `k2`_. +Verify that `k2`_ has been installed successfully: -(4) Install lhotse +.. code-block:: bash + + (test-icefall) kuangfangjun:~$ python3 -m k2.version + + Collecting environment information... + + k2 version: 1.24.3 + Build type: Release + Git SHA1: 4c05309499a08454997adf500b56dcc629e35ae5 + Git date: Tue Jul 25 16:23:36 2023 + Cuda used to build k2: 11.6 + cuDNN used to build k2: 8.3.2 + Python version used to build k2: 3.8 + OS used to build k2: CentOS Linux release 7.9.2009 (Core) + CMake version: 3.27.0 + GCC version: 9.3.1 + CMAKE_CUDA_FLAGS: -Wno-deprecated-gpu-targets -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_80,code=sm_80 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_86,code=sm_86 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_86,code=compute_86 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-strict-overflow --compiler-options -Wno-unknown-pragmas + CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable -Wno-strict-overflow + PyTorch version used to build k2: 1.13.0+cu116 + PyTorch is using Cuda: 11.6 + NVTX enabled: True + With CUDA: True + Disable debug: True + Sync kernels : False + Disable checks: False + Max cpu memory allocate: 214748364800 bytes (or 200.0 GB) + k2 abort: False + __file__: /star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/version/version.py + _k2.__file__: /star-fj/fangjun/test-icefall/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so + +(5) Install lhotse ~~~~~~~~~~~~~~~~~~ -.. code-block:: +.. 
code-block:: bash - $ pip install git+https://github.com/lhotse-speech/lhotse + (test-icefall) kuangfangjun:~$ pip install git+https://github.com/lhotse-speech/lhotse Collecting git+https://github.com/lhotse-speech/lhotse - Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-7b1b76ge - Running command git clone -q https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-7b1b76ge - Collecting audioread>=2.1.9 - Using cached audioread-2.1.9-py3-none-any.whl - Collecting SoundFile>=0.10 - Using cached SoundFile-0.10.3.post1-py2.py3-none-any.whl (21 kB) - Collecting click>=7.1.1 - Using cached click-8.0.1-py3-none-any.whl (97 kB) + Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-vq12fd5i + Running command git clone --filter=blob:none --quiet https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-vq12fd5i + Resolved https://github.com/lhotse-speech/lhotse to commit 7640d663469b22cd0b36f3246ee9b849cd25e3b7 + Installing build dependencies ... done + Getting requirements to build wheel ... done + Preparing metadata (pyproject.toml) ... done Collecting cytoolz>=0.10.1 - Using cached cytoolz-0.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB) - Collecting dataclasses - Using cached dataclasses-0.6-py3-none-any.whl (14 kB) - Collecting h5py>=2.10.0 - Downloading h5py-3.4.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB) - |________________________________| 4.5 MB 684 kB/s - Collecting intervaltree>=3.1.0 - Using cached intervaltree-3.1.0-py2.py3-none-any.whl - Collecting lilcom>=1.1.0 - Using cached lilcom-1.1.1-cp38-cp38-linux_x86_64.whl - Collecting numpy>=1.18.1 - Using cached numpy-1.21.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.8 MB) - Collecting packaging - Using cached packaging-21.0-py3-none-any.whl (40 kB) + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1e/3b/a7828d575aa17fb7acaf1ced49a3655aa36dad7e16eb7e6a2e4df0dda76f/cytoolz-0.12.2-cp38-cp38- + manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB) + ________________________________________ 2.0/2.0 MB 33.2 MB/s eta 0:00:00 Collecting pyyaml>=5.3.1 - Using cached PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB) + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c8/6b/6600ac24725c7388255b2f5add93f91e58a5d7efaf4af244fdbcc11a541b/PyYAML-6.0.1-cp38-cp38-ma + nylinux_2_17_x86_64.manylinux2014_x86_64.whl (736 kB) + ________________________________________ 736.6/736.6 kB 38.6 MB/s eta 0:00:00 + Collecting dataclasses + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/26/2f/1095cdc2868052dd1e64520f7c0d5c8c550ad297e944e641dbf1ffbb9a5d/dataclasses-0.6-py3-none- + any.whl (14 kB) + Requirement already satisfied: torchaudio in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (0.13.0+cu116) + Collecting lilcom>=1.1.0 + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a8/65/df0a69c52bd085ca1ad4e5c4c1a5c680e25f9477d8e49316c4ff1e5084a4/lilcom-1.7-cp38-cp38-many + linux_2_17_x86_64.manylinux2014_x86_64.whl (87 kB) + ________________________________________ 87.1/87.1 kB 8.7 MB/s eta 0:00:00 Collecting tqdm - Downloading tqdm-4.62.1-py2.py3-none-any.whl (76 kB) - |________________________________| 76 kB 2.7 MB/s - Collecting torchaudio==0.9.0 - Downloading torchaudio-0.9.0-cp38-cp38-manylinux1_x86_64.whl (1.9 MB) - |________________________________| 1.9 MB 73.1 MB/s - Requirement already satisfied: torch==1.9.0 in ./test-icefall/lib/python3.8/site-packages (from 
torchaudio==0.9.0->lhotse===0.8.0.dev - -2a1410b-clean) (1.9.0) - Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch==1.9.0->torchaudio==0.9.0- - >lhotse===0.8.0.dev-2a1410b-clean) (3.10.0.0) + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/e6/02/a2cff6306177ae6bc73bc0665065de51dfb3b9db7373e122e2735faf0d97/tqdm-4.65.0-py3-none-any + .whl (77 kB) + Requirement already satisfied: numpy>=1.18.1 in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (1.24.4) + Collecting audioread>=2.1.9 + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/5d/cb/82a002441902dccbe427406785db07af10182245ee639ea9f4d92907c923/audioread-3.0.0.tar.gz ( + 377 kB) + Preparing metadata (setup.py) ... done + Collecting tabulate>=0.8.1 + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none- + any.whl (35 kB) + Collecting click>=7.1.1 + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1a/70/e63223f8116931d365993d4a6b7ef653a4d920b41d03de7c59499962821f/click-8.1.6-py3-none-any. + whl (97 kB) + ________________________________________ 97.9/97.9 kB 8.4 MB/s eta 0:00:00 + Collecting packaging + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none- + any.whl (48 kB) + Collecting intervaltree>=3.1.0 + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/50/fb/396d568039d21344639db96d940d40eb62befe704ef849b27949ded5c3bb/intervaltree-3.1.0.tar.gz + (32 kB) + Preparing metadata (setup.py) ... done + Requirement already satisfied: torch in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (1.13.0+cu116) + Collecting SoundFile>=0.10 + Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ad/bd/0602167a213d9184fc688b1086dc6d374b7ae8c33eccf169f9b50ce6568c/soundfile-0.12.1-py2.py3- + none-manylinux_2_17_x86_64.whl (1.3 MB) + ________________________________________ 1.3/1.3 MB 46.5 MB/s eta 0:00:00 Collecting toolz>=0.8.0 - Using cached toolz-0.11.1-py3-none-any.whl (55 kB) + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/7f/5c/922a3508f5bda2892be3df86c74f9cf1e01217c2b1f8a0ac4841d903e3e9/toolz-0.12.0-py3-none-any.whl (55 kB) Collecting sortedcontainers<3.0,>=2.0 - Using cached sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB) + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB) Collecting cffi>=1.0 - Using cached cffi-1.14.6-cp38-cp38-manylinux1_x86_64.whl (411 kB) + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/b7/8b/06f30caa03b5b3ac006de4f93478dbd0239e2a16566d81a106c322dc4f79/cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (442 kB) + Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch->lhotse==1.16.0.dev0+git.7640d66.clean) (4.7.1) Collecting pycparser - Using cached pycparser-2.20-py2.py3-none-any.whl (112 kB) - Collecting pyparsing>=2.0.2 - Using cached pyparsing-2.4.7-py2.py3-none-any.whl (67 kB) - Building wheels for collected packages: lhotse - Building wheel for lhotse (setup.py) ... 
done - Created wheel for lhotse: filename=lhotse-0.8.0.dev_2a1410b_clean-py3-none-any.whl size=342242 sha256=f683444afa4dc0881133206b4646a - 9d0f774224cc84000f55d0a67f6e4a37997 - Stored in directory: /tmp/pip-ephem-wheel-cache-ftu0qysz/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f - WARNING: Built wheel for lhotse is invalid: Metadata 1.2 mandates PEP 440 version, but '0.8.0.dev-2a1410b-clean' is not - Failed to build lhotse - Installing collected packages: pycparser, toolz, sortedcontainers, pyparsing, numpy, cffi, tqdm, torchaudio, SoundFile, pyyaml, packa - ging, lilcom, intervaltree, h5py, dataclasses, cytoolz, click, audioread, lhotse - Running setup.py install for lhotse ... done - DEPRECATION: lhotse was installed using the legacy 'setup.py install' method, because a wheel could not be built for it. A possible - replacement is to fix the wheel build issue reported above. You can find discussion regarding this at https://github.com/pypa/pip/is - sues/8368. - Successfully installed SoundFile-0.10.3.post1 audioread-2.1.9 cffi-1.14.6 click-8.0.1 cytoolz-0.11.0 dataclasses-0.6 h5py-3.4.0 inter - valtree-3.1.0 lhotse-0.8.0.dev-2a1410b-clean lilcom-1.1.1 numpy-1.21.2 packaging-21.0 pycparser-2.20 pyparsing-2.4.7 pyyaml-5.4.1 sor - tedcontainers-2.4.0 toolz-0.11.1 torchaudio-0.9.0 tqdm-4.62.1 + Using cached https://pypi.tuna.tsinghua.edu.cn/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl (118 kB) + Building wheels for collected packages: lhotse, audioread, intervaltree + Building wheel for lhotse (pyproject.toml) ... done + Created wheel for lhotse: filename=lhotse-1.16.0.dev0+git.7640d66.clean-py3-none-any.whl size=687627 sha256=cbf0a4d2d0b639b33b91637a4175bc251d6a021a069644ecb1a9f2b3a83d072a + Stored in directory: /tmp/pip-ephem-wheel-cache-wwtk90_m/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f + Building wheel for audioread (setup.py) ... done + Created wheel for audioread: filename=audioread-3.0.0-py3-none-any.whl size=23704 sha256=5e2d3537c96ce9cf0f645a654c671163707bf8cb8d9e358d0e2b0939a85ff4c2 + Stored in directory: /star-fj/fangjun/.cache/pip/wheels/e2/c3/9c/f19ae5a03f8862d9f0776b0c0570f1fdd60a119d90954e3f39 + Building wheel for intervaltree (setup.py) ... done + Created wheel for intervaltree: filename=intervaltree-3.1.0-py2.py3-none-any.whl size=26098 sha256=2604170976cfffe0d2f678cb1a6e5b525f561cd50babe53d631a186734fec9f9 + Stored in directory: /star-fj/fangjun/.cache/pip/wheels/f3/ed/2b/c179ebfad4e15452d6baef59737f27beb9bfb442e0620f7271 + Successfully built lhotse audioread intervaltree + Installing collected packages: sortedcontainers, dataclasses, tqdm, toolz, tabulate, pyyaml, pycparser, packaging, lilcom, intervaltree, click, audioread, cytoolz, cffi, SoundFile, lhotse + Successfully installed SoundFile-0.12.1 audioread-3.0.0 cffi-1.15.1 click-8.1.6 cytoolz-0.12.2 dataclasses-0.6 intervaltree-3.1.0 lhotse-1.16.0.dev0+git.7640d66.clean lilcom-1.7 packaging-23.1 pycparser-2.21 pyyaml-6.0.1 sortedcontainers-2.4.0 tabulate-0.9.0 toolz-0.12.0 tqdm-4.65.0 -(5) Download icefall + +Verify that `lhotse`_ has been installed successfully: + +.. code-block:: bash + + (test-icefall) kuangfangjun:~$ python3 -c "import lhotse; print(lhotse.__version__)" + + 1.16.0.dev+git.7640d66.clean + +(6) Download icefall ~~~~~~~~~~~~~~~~~~~~ -.. code-block:: +.. 
code-block:: bash - $ cd /tmp - $ git clone https://github.com/k2-fsa/icefall + (test-icefall) kuangfangjun:~$ cd /tmp/ + + (test-icefall) kuangfangjun:tmp$ git clone https://github.com/k2-fsa/icefall Cloning into 'icefall'... - remote: Enumerating objects: 500, done. - remote: Counting objects: 100% (500/500), done. - remote: Compressing objects: 100% (308/308), done. - remote: Total 500 (delta 263), reused 307 (delta 102), pack-reused 0 - Receiving objects: 100% (500/500), 172.49 KiB | 385.00 KiB/s, done. - Resolving deltas: 100% (263/263), done. + remote: Enumerating objects: 12942, done. + remote: Counting objects: 100% (67/67), done. + remote: Compressing objects: 100% (56/56), done. + remote: Total 12942 (delta 17), reused 35 (delta 6), pack-reused 12875 + Receiving objects: 100% (12942/12942), 14.77 MiB | 9.29 MiB/s, done. + Resolving deltas: 100% (8835/8835), done. - $ cd icefall - $ pip install -r requirements.txt - - Collecting kaldilm - Downloading kaldilm-1.8.tar.gz (48 kB) - |________________________________| 48 kB 574 kB/s - Collecting kaldialign - Using cached kaldialign-0.2-cp38-cp38-linux_x86_64.whl - Collecting sentencepiece>=0.1.96 - Using cached sentencepiece-0.1.96-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB) - Collecting tensorboard - Using cached tensorboard-2.6.0-py3-none-any.whl (5.6 MB) - Requirement already satisfied: setuptools>=41.0.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r - requirements.txt (line 4)) (57.4.0) - Collecting absl-py>=0.4 - Using cached absl_py-0.13.0-py3-none-any.whl (132 kB) - Collecting google-auth-oauthlib<0.5,>=0.4.1 - Using cached google_auth_oauthlib-0.4.5-py2.py3-none-any.whl (18 kB) - Collecting grpcio>=1.24.3 - Using cached grpcio-1.39.0-cp38-cp38-manylinux2014_x86_64.whl (4.3 MB) - Requirement already satisfied: wheel>=0.26 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r require - ments.txt (line 4)) (0.36.2) - Requirement already satisfied: numpy>=1.12.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r requi - rements.txt (line 4)) (1.21.2) - Collecting protobuf>=3.6.0 - Using cached protobuf-3.17.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB) - Collecting werkzeug>=0.11.15 - Using cached Werkzeug-2.0.1-py3-none-any.whl (288 kB) - Collecting tensorboard-data-server<0.7.0,>=0.6.0 - Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB) - Collecting google-auth<2,>=1.6.3 - Downloading google_auth-1.35.0-py2.py3-none-any.whl (152 kB) - |________________________________| 152 kB 1.4 MB/s - Collecting requests<3,>=2.21.0 - Using cached requests-2.26.0-py2.py3-none-any.whl (62 kB) - Collecting tensorboard-plugin-wit>=1.6.0 - Using cached tensorboard_plugin_wit-1.8.0-py3-none-any.whl (781 kB) - Collecting markdown>=2.6.8 - Using cached Markdown-3.3.4-py3-none-any.whl (97 kB) - Collecting six - Using cached six-1.16.0-py2.py3-none-any.whl (11 kB) - Collecting cachetools<5.0,>=2.0.0 - Using cached cachetools-4.2.2-py3-none-any.whl (11 kB) - Collecting rsa<5,>=3.1.4 - Using cached rsa-4.7.2-py3-none-any.whl (34 kB) - Collecting pyasn1-modules>=0.2.1 - Using cached pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB) - Collecting requests-oauthlib>=0.7.0 - Using cached requests_oauthlib-1.3.0-py2.py3-none-any.whl (23 kB) - Collecting pyasn1<0.5.0,>=0.4.6 - Using cached pyasn1-0.4.8-py2.py3-none-any.whl (77 kB) - Collecting urllib3<1.27,>=1.21.1 - Using cached 
urllib3-1.26.6-py2.py3-none-any.whl (138 kB) - Collecting certifi>=2017.4.17 - Using cached certifi-2021.5.30-py2.py3-none-any.whl (145 kB) - Collecting charset-normalizer~=2.0.0 - Using cached charset_normalizer-2.0.4-py3-none-any.whl (36 kB) - Collecting idna<4,>=2.5 - Using cached idna-3.2-py3-none-any.whl (59 kB) - Collecting oauthlib>=3.0.0 - Using cached oauthlib-3.1.1-py2.py3-none-any.whl (146 kB) - Building wheels for collected packages: kaldilm - Building wheel for kaldilm (setup.py) ... done - Created wheel for kaldilm: filename=kaldilm-1.8-cp38-cp38-linux_x86_64.whl size=897233 sha256=eccb906cafcd45bf9a7e1a1718e4534254bfb - f4c0d0cbc66eee6c88d68a63862 - Stored in directory: /root/fangjun/.cache/pip/wheels/85/7d/63/f2dd586369b8797cb36d213bf3a84a789eeb92db93d2e723c9 - Successfully built kaldilm - Installing collected packages: urllib3, pyasn1, idna, charset-normalizer, certifi, six, rsa, requests, pyasn1-modules, oauthlib, cach - etools, requests-oauthlib, google-auth, werkzeug, tensorboard-plugin-wit, tensorboard-data-server, protobuf, markdown, grpcio, google - -auth-oauthlib, absl-py, tensorboard, sentencepiece, kaldilm, kaldialign - Successfully installed absl-py-0.13.0 cachetools-4.2.2 certifi-2021.5.30 charset-normalizer-2.0.4 google-auth-1.35.0 google-auth-oaut - hlib-0.4.5 grpcio-1.39.0 idna-3.2 kaldialign-0.2 kaldilm-1.8 markdown-3.3.4 oauthlib-3.1.1 protobuf-3.17.3 pyasn1-0.4.8 pyasn1-module - s-0.2.8 requests-2.26.0 requests-oauthlib-1.3.0 rsa-4.7.2 sentencepiece-0.1.96 six-1.16.0 tensorboard-2.6.0 tensorboard-data-server-0 - .6.1 tensorboard-plugin-wit-1.8.0 urllib3-1.26.6 werkzeug-2.0.1 + (test-icefall) kuangfangjun:tmp$ cd icefall/ + (test-icefall) kuangfangjun:icefall$ pip install -r ./requirements.txt Test Your Installation ---------------------- To test that your installation is successful, let us run the `yesno recipe `_ -on CPU. +on ``CPU``. Data preparation ~~~~~~~~~~~~~~~~ .. code-block:: bash - $ export PYTHONPATH=/tmp/icefall:$PYTHONPATH - $ cd /tmp/icefall - $ cd egs/yesno/ASR - $ ./prepare.sh + (test-icefall) kuangfangjun:icefall$ export PYTHONPATH=/tmp/icefall:$PYTHONPATH + + (test-icefall) kuangfangjun:icefall$ cd /tmp/icefall + + (test-icefall) kuangfangjun:icefall$ cd egs/yesno/ASR + + (test-icefall) kuangfangjun:ASR$ ./prepare.sh + The log of running ``./prepare.sh`` is: .. code-block:: - 2023-05-12 17:55:21 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download - 2023-05-12 17:55:21 (prepare.sh:30:main) Stage 0: Download data - /tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|_______________________________________________________________| 4.70M/4.70M [06:54<00:00, 11.4kB/s] - 2023-05-12 18:02:19 (prepare.sh:39:main) Stage 1: Prepare yesno manifest - 2023-05-12 18:02:21 (prepare.sh:45:main) Stage 2: Compute fbank for yesno - 2023-05-12 18:02:23,199 INFO [compute_fbank_yesno.py:65] Processing train - Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:00<00:00, 212.60it/s] - 2023-05-12 18:02:23,640 INFO [compute_fbank_yesno.py:65] Processing test - Extracting and storing features: 100%|_______________________________________________________________| 30/30 [00:00<00:00, 304.53it/s] - 2023-05-12 18:02:24 (prepare.sh:51:main) Stage 3: Prepare lang - 2023-05-12 18:02:26 (prepare.sh:66:main) Stage 4: Prepare G - /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79 - [I] Reading \data\ section. 
- /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140 - [I] Reading \1-grams: section. - 2023-05-12 18:02:26 (prepare.sh:92:main) Stage 5: Compile HLG - 2023-05-12 18:02:28,581 INFO [compile_hlg.py:124] Processing data/lang_phone - 2023-05-12 18:02:28,582 INFO [lexicon.py:171] Converting L.pt to Linv.pt - 2023-05-12 18:02:28,609 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3 - 2023-05-12 18:02:28,610 INFO [compile_hlg.py:52] Loading G.fst.txt - 2023-05-12 18:02:28,611 INFO [compile_hlg.py:62] Intersecting L and G - 2023-05-12 18:02:28,613 INFO [compile_hlg.py:64] LG shape: (4, None) - 2023-05-12 18:02:28,613 INFO [compile_hlg.py:66] Connecting LG - 2023-05-12 18:02:28,614 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None) - 2023-05-12 18:02:28,614 INFO [compile_hlg.py:70] - 2023-05-12 18:02:28,614 INFO [compile_hlg.py:71] Determinizing LG - 2023-05-12 18:02:28,615 INFO [compile_hlg.py:74] - 2023-05-12 18:02:28,615 INFO [compile_hlg.py:76] Connecting LG after k2.determinize - 2023-05-12 18:02:28,615 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG - 2023-05-12 18:02:28,616 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None) - 2023-05-12 18:02:28,617 INFO [compile_hlg.py:96] Arc sorting LG - 2023-05-12 18:02:28,617 INFO [compile_hlg.py:99] Composing H and LG - 2023-05-12 18:02:28,619 INFO [compile_hlg.py:106] Connecting LG - 2023-05-12 18:02:28,619 INFO [compile_hlg.py:109] Arc sorting LG - 2023-05-12 18:02:28,619 INFO [compile_hlg.py:111] HLG.shape: (8, None) - 2023-05-12 18:02:28,619 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone - + 2023-07-27 12:41:39 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download + 2023-07-27 12:41:39 (prepare.sh:30:main) Stage 0: Download data + /tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|___________________________________________________| 4.70M/4.70M [00:00<00:00, 11.1MB/s] + 2023-07-27 12:41:46 (prepare.sh:39:main) Stage 1: Prepare yesno manifest + 2023-07-27 12:41:50 (prepare.sh:45:main) Stage 2: Compute fbank for yesno + 2023-07-27 12:41:55,718 INFO [compute_fbank_yesno.py:65] Processing train + Extracting and storing features: 100%|_______________________________________________________________________________| 90/90 [00:01<00:00, 87.82it/s] + 2023-07-27 12:41:56,778 INFO [compute_fbank_yesno.py:65] Processing test + Extracting and storing features: 100%|______________________________________________________________________________| 30/30 [00:00<00:00, 256.92it/s] + 2023-07-27 12:41:57 (prepare.sh:51:main) Stage 3: Prepare lang + 2023-07-27 12:42:02 (prepare.sh:66:main) Stage 4: Prepare G + /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79 + [I] Reading \data\ section. + /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140 + [I] Reading \1-grams: section. + 2023-07-27 12:42:02 (prepare.sh:92:main) Stage 5: Compile HLG + 2023-07-27 12:42:07,275 INFO [compile_hlg.py:124] Processing data/lang_phone + 2023-07-27 12:42:07,276 INFO [lexicon.py:171] Converting L.pt to Linv.pt + 2023-07-27 12:42:07,309 INFO [compile_hlg.py:48] Building ctc_topo. 
max_token_id: 3 + 2023-07-27 12:42:07,310 INFO [compile_hlg.py:52] Loading G.fst.txt + 2023-07-27 12:42:07,314 INFO [compile_hlg.py:62] Intersecting L and G + 2023-07-27 12:42:07,323 INFO [compile_hlg.py:64] LG shape: (4, None) + 2023-07-27 12:42:07,323 INFO [compile_hlg.py:66] Connecting LG + 2023-07-27 12:42:07,323 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None) + 2023-07-27 12:42:07,323 INFO [compile_hlg.py:70] + 2023-07-27 12:42:07,323 INFO [compile_hlg.py:71] Determinizing LG + 2023-07-27 12:42:07,341 INFO [compile_hlg.py:74] + 2023-07-27 12:42:07,341 INFO [compile_hlg.py:76] Connecting LG after k2.determinize + 2023-07-27 12:42:07,341 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG + 2023-07-27 12:42:07,354 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None) + 2023-07-27 12:42:07,445 INFO [compile_hlg.py:96] Arc sorting LG + 2023-07-27 12:42:07,445 INFO [compile_hlg.py:99] Composing H and LG + 2023-07-27 12:42:07,446 INFO [compile_hlg.py:106] Connecting LG + 2023-07-27 12:42:07,446 INFO [compile_hlg.py:109] Arc sorting LG + 2023-07-27 12:42:07,447 INFO [compile_hlg.py:111] HLG.shape: (8, None) + 2023-07-27 12:42:07,447 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone Training ~~~~~~~~ @@ -409,12 +430,13 @@ Now let us run the training part: .. code-block:: - $ export CUDA_VISIBLE_DEVICES="" - $ ./tdnn/train.py + (test-icefall) kuangfangjun:ASR$ export CUDA_VISIBLE_DEVICES="" + + (test-icefall) kuangfangjun:ASR$ ./tdnn/train.py .. CAUTION:: - We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU + We use ``export CUDA_VISIBLE_DEVICES=""`` so that `icefall`_ uses CPU even if there are GPUs available. .. hint:: @@ -432,53 +454,52 @@ The training log is given below: .. code-block:: - 2023-05-12 18:04:59,759 INFO [train.py:481] Training started - 2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, - 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, - 'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, - 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, - 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023', - 'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master', - 'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall', - 'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py', - 'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}} - 2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt - 2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu - 
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts - 2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] About to get train cuts - 2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset - 2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler. - 2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader - 2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts - 2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts - 2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4 - 2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4 - 2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames. - 2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5 - 2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames. - 2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt - 2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4 - 2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4 - 2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames. - 2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5 - 2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames. 
- 2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt + 2023-07-27 12:50:51,936 INFO [train.py:481] Training started + 2023-07-27 12:50:51,936 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d66.clean', 'torch-version': '1.13.0+cu116', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '3fb0a43-clean', 'icefall-git-date': 'Thu Jul 27 12:36:05 2023', 'icefall-path': '/tmp/icefall', 'k2-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-1-1220091118-57c4d55446-sph26', 'IP address': '10.177.77.20'}} + 2023-07-27 12:50:51,941 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-07-27 12:50:51,949 INFO [train.py:495] device: cpu + 2023-07-27 12:50:51,965 INFO [asr_datamodule.py:146] About to get train cuts + 2023-07-27 12:50:51,965 INFO [asr_datamodule.py:244] About to get train cuts + 2023-07-27 12:50:51,967 INFO [asr_datamodule.py:149] About to create train dataset + 2023-07-27 12:50:51,967 INFO [asr_datamodule.py:199] Using SingleCutSampler. + 2023-07-27 12:50:51,967 INFO [asr_datamodule.py:205] About to create train dataloader + 2023-07-27 12:50:51,968 INFO [asr_datamodule.py:218] About to get test cuts + 2023-07-27 12:50:51,968 INFO [asr_datamodule.py:252] About to get test cuts + 2023-07-27 12:50:52,565 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4 + 2023-07-27 12:50:53,681 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames.], batch size: 4 + 2023-07-27 12:50:54,167 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames. + 2023-07-27 12:50:55,011 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5 + 2023-07-27 12:50:55,331 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames. + 2023-07-27 12:50:55,368 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt + 2023-07-27 12:50:55,633 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], + batch size: 4 + 2023-07-27 12:50:56,242 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. 
], tot_loss[loss=0.1592, over 22192.90 frames.], batch size: 4 + 2023-07-27 12:50:56,522 INFO [train.py:444] Epoch 1, validation loss=0.1627, over 18067.00 frames. + 2023-07-27 12:50:57,209 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.07055, over 2695.00 frames. ], tot_loss[loss=0.1175, over 34971.47 frames.], batch size: 5 + 2023-07-27 12:50:57,600 INFO [train.py:444] Epoch 1, validation loss=0.07091, over 18067.00 frames. + 2023-07-27 12:50:57,640 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt + 2023-07-27 12:50:57,847 INFO [train.py:422] Epoch 2, batch 0, loss[loss=0.07731, over 2436.00 frames. ], tot_loss[loss=0.07731, over 2436.00 frames.], batch size: 4 + 2023-07-27 12:50:58,427 INFO [train.py:422] Epoch 2, batch 10, loss[loss=0.04391, over 2828.00 frames. ], tot_loss[loss=0.05341, over 22192.90 frames. ], batch size: 4 + 2023-07-27 12:50:58,884 INFO [train.py:444] Epoch 2, validation loss=0.04384, over 18067.00 frames. + 2023-07-27 12:50:59,387 INFO [train.py:422] Epoch 2, batch 20, loss[loss=0.03458, over 2695.00 frames. ], tot_loss[loss=0.04616, over 34971.47 frames. ], batch size: 5 + 2023-07-27 12:50:59,707 INFO [train.py:444] Epoch 2, validation loss=0.03379, over 18067.00 frames. + 2023-07-27 12:50:59,758 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-2.pt - ... ... + ... ... - 2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4 - 2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4 - 2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames. - 2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5 - 2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames. - 2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt - 2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4 - 2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4 - 2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames. - 2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5 - 2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames. - 2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt - 2023-05-12 18:05:16,865 INFO [train.py:555] Done! + 2023-07-27 12:51:23,433 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01054, over 2436.00 frames. ], tot_loss[loss=0.01054, over 2436.00 frames. ], batch size: 4 + 2023-07-27 12:51:23,980 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009014, over 2828.00 frames. ], tot_loss[loss=0.009974, over 22192.90 frames. ], batch size: 4 + 2023-07-27 12:51:24,489 INFO [train.py:444] Epoch 13, validation loss=0.01085, over 18067.00 frames. 
+ 2023-07-27 12:51:25,258 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01172, over 2695.00 frames. ], tot_loss[loss=0.01055, over 34971.47 frames. ], batch size: 5 + 2023-07-27 12:51:25,621 INFO [train.py:444] Epoch 13, validation loss=0.01074, over 18067.00 frames. + 2023-07-27 12:51:25,699 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt + 2023-07-27 12:51:25,866 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01044, over 2436.00 frames. ], tot_loss[loss=0.01044, over 2436.00 frames. ], batch size: 4 + 2023-07-27 12:51:26,844 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008942, over 2828.00 frames. ], tot_loss[loss=0.01, over 22192.90 frames. ], batch size: 4 + 2023-07-27 12:51:27,221 INFO [train.py:444] Epoch 14, validation loss=0.01082, over 18067.00 frames. + 2023-07-27 12:51:27,970 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01054, over 34971.47 frames. ], batch size: 5 + 2023-07-27 12:51:28,247 INFO [train.py:444] Epoch 14, validation loss=0.01073, over 18067.00 frames. + 2023-07-27 12:51:28,323 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt + 2023-07-27 12:51:28,326 INFO [train.py:555] Done! Decoding ~~~~~~~~ @@ -487,42 +508,32 @@ Let us use the trained model to decode the test set: .. code-block:: - $ ./tdnn/decode.py + (test-icefall) kuangfangjun:ASR$ ./tdnn/decode.py -The decoding log is: + 2023-07-27 12:55:12,840 INFO [decode.py:263] Decoding started + 2023-07-27 12:55:12,840 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d66.clean', 'torch-version': '1.13.0+cu116', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '3fb0a43-clean', 'icefall-git-date': 'Thu Jul 27 12:36:05 2023', 'icefall-path': '/tmp/icefall', 'k2-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-1-1220091118-57c4d55446-sph26', 'IP address': '10.177.77.20'}} + 2023-07-27 12:55:12,841 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-07-27 12:55:12,855 INFO [decode.py:273] device: cpu + 2023-07-27 12:55:12,868 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-07-27 12:55:12,882 INFO [asr_datamodule.py:218] About to get test cuts + 2023-07-27 12:55:12,883 INFO [asr_datamodule.py:252] About to get test cuts + 2023-07-27 12:55:13,157 INFO [decode.py:204] batch 0/?, cuts processed until now is 4 + 2023-07-27 12:55:13,701 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt + 2023-07-27 12:55:13,702 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 
ins, 1 del, 0 sub ] + 2023-07-27 12:55:13,704 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt + 2023-07-27 12:55:13,704 INFO [decode.py:316] Done! -.. code-block:: - 2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started - 2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, - 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), - 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, - 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023', - 'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master', - 'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall', - 'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py', - 'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}} - 2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt - 2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu - 2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] - 2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts - 2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts - 2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4 - 2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt - 2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] - 2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt - 2023-05-12 18:08:30,925 INFO [decode.py:316] Done! - -**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``. +**Congratulations!** You have successfully setup the environment and have run the first recipe in `icefall`_. Have fun with ``icefall``! YouTube Video ------------- -We provide the following YouTube video showing how to install ``icefall``. +We provide the following YouTube video showing how to install `icefall`_. It also shows how to debug various problems that you may encounter while -using ``icefall``. +using `icefall`_. .. 
note:: From 751bb6ff1a933c69a5ad4aebe8e24972f14dd691 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 28 Jul 2023 10:34:40 +0800 Subject: [PATCH 011/113] Add docker image for icefall (#1189) --- .github/workflows/build-docker-image.yml | 45 ++++++++++++++++ .github/workflows/run-docker-image.yml | 66 ++++++++++++++++++++++++ docker/README.md | 15 ++++++ docker/torch1.12.1-cuda11.3.dockerfile | 62 ++++++++++++++++++++++ docker/torch1.13.0-cuda11.6.dockerfile | 64 +++++++++++++++++++++++ docker/torch1.9.0-cuda10.2.dockerfile | 62 ++++++++++++++++++++++ docker/torch2.0.0-cuda11.7.dockerfile | 62 ++++++++++++++++++++++ 7 files changed, 376 insertions(+) create mode 100644 .github/workflows/build-docker-image.yml create mode 100644 .github/workflows/run-docker-image.yml create mode 100644 docker/torch1.12.1-cuda11.3.dockerfile create mode 100644 docker/torch1.13.0-cuda11.6.dockerfile create mode 100644 docker/torch1.9.0-cuda10.2.dockerfile create mode 100644 docker/torch2.0.0-cuda11.7.dockerfile diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml new file mode 100644 index 000000000..327f0ee45 --- /dev/null +++ b/.github/workflows/build-docker-image.yml @@ -0,0 +1,45 @@ +# see also +# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages +name: Build docker image +on: + workflow_dispatch: + +concurrency: + group: build_docker-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-docker-image: + name: ${{ matrix.image }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Rename + shell: bash + run: | + image=${{ matrix.image }} + mv -v ./docker/$image.dockerfile ./Dockerfile + + - name: Log in to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Build and push + uses: docker/build-push-action@v4 + with: + context: . 
+ file: ./Dockerfile + push: true + tags: k2fsa/icefall:${{ matrix.image }} diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml new file mode 100644 index 000000000..d0ac11071 --- /dev/null +++ b/.github/workflows/run-docker-image.yml @@ -0,0 +1,66 @@ +name: Run docker image +on: + workflow_dispatch: + +concurrency: + group: run_docker_image-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-docker-image: + name: ${{ matrix.image }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Run the build process with Docker + uses: addnab/docker-run-action@v3 + with: + image: k2fsa/icefall:${{ matrix.image }} + run: | + uname -a + cat /etc/*release + + nvcc --version + + which nvcc + cuda_dir=$(dirname $(which nvcc)) + echo "cuda_dir: $cuda_dir" + + find $cuda_dir -name libcuda.so* + echo "--------------------" + + find / -name libcuda.so* 2>/dev/null + + pushd /opt/conda/lib/stubs && ln -s libcuda.so libcuda.so.1 && popd + + export LD_LIBRARY_PATH=/opt/conda/lib/stubs:$LD_LIBRARY_PATH + echo "LD_LIBRARY_PATH $LD_LIBRARY_PATH" + + python3 --version + which python3 + + echo "----------torch----------" + python3 -m torch.utils.collect_env + + echo "----------k2----------" + python3 -c "import k2; print(k2.__file__)" + python3 -c "import k2; print(k2.__version__)" + python3 -m k2.version + + echo "----------lhotse----------" + python3 -c "import lhotse; print(lhotse.__file__)" + python3 -c "import lhotse; print(lhotse.__version__)" + + echo "----------kaldifeat----------" + python3 -c "import kaldifeat; print(kaldifeat.__file__)" + python3 -c "import kaldifeat; print(kaldifeat.__version__)" + diff --git a/docker/README.md b/docker/README.md index c14b9bf75..19959bfe6 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,5 +1,20 @@ # icefall dockerfile +## Download from dockerhub + +You can find pre-built docker image for icefall at the following address: + + + +Example usage: + +```bash +docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash +``` + + +## Build from dockerfile + 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. 
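The README hunk above shows how to pull a pre-built image from Docker Hub; for its "Build from dockerfile" path it only lists the configurations. A minimal sketch of building and entering one of the images added by the dockerfiles below might look like this (the local tag name `icefall-local` is an assumption for illustration; pick the dockerfile matching your NVIDIA driver):

```bash
# Get the dockerfiles added in this patch
git clone https://github.com/k2-fsa/icefall
cd icefall

# Build an image from one of the provided dockerfiles,
# e.g. the torch1.13.0-cuda11.6 variant (the local tag is illustrative)
docker build -t icefall-local:torch1.13.0-cuda11.6 \
  -f docker/torch1.13.0-cuda11.6.dockerfile .

# Start an interactive container with GPU access,
# mirroring the docker run example in the README above
docker run --gpus all --rm -it icefall-local:torch1.13.0-cuda11.6 /bin/bash
```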
diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile new file mode 100644 index 000000000..c5e252abb --- /dev/null +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -0,0 +1,62 @@ +FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +ARG K2_VERSION="1.24.3.dev20230725+cuda11.3.torch1.12.1" +ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.3.torch1.12.1" +ARG TORCHAUDIO_VERSION="0.12.1+cu113" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile new file mode 100644 index 000000000..bcbf8b599 --- /dev/null +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -0,0 +1,64 @@ +FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +ARG K2_VERSION="1.24.3.dev20230725+cuda11.6.torch1.13.0" +ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.6.torch1.13.0" +ARG TORCHAUDIO_VERSION="0.13.0+cu116" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +ENV LD_LIBRARY_PATH /opt/conda/lib/stubs:$LD_LIBRARY_PATH + +WORKDIR /workspace/icefall diff --git a/docker/torch1.9.0-cuda10.2.dockerfile 
b/docker/torch1.9.0-cuda10.2.dockerfile new file mode 100644 index 000000000..7553fcf86 --- /dev/null +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -0,0 +1,62 @@ +FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0" +ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda10.2.torch1.9.0" +ARG TORCHAUDIO_VERSION="0.9.0" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile new file mode 100644 index 000000000..c11c0bd67 --- /dev/null +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -0,0 +1,62 @@ +FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +ARG K2_VERSION="1.24.3.dev20230718+cuda11.7.torch2.0.0" +ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.7.torch2.0.0" +ARG TORCHAUDIO_VERSION="2.0.0+cu117" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall From 375520d419826485a206115d66b1471934295081 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 28 Jul 2023 15:43:08 +0800 Subject: [PATCH 012/113] Run the yesno recipe with docker in GitHub actions 
(#1191) --- .github/workflows/run-docker-image.yml | 34 +++++++++++++++++++++++--- docker/torch1.12.1-cuda11.3.dockerfile | 12 +++++++-- docker/torch1.13.0-cuda11.6.dockerfile | 10 +++++++- docker/torch1.9.0-cuda10.2.dockerfile | 30 ++++++++++++++++++++--- docker/torch2.0.0-cuda11.7.dockerfile | 12 +++++++-- 5 files changed, 86 insertions(+), 12 deletions(-) diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index d0ac11071..12604a132 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -25,12 +25,23 @@ jobs: uses: addnab/docker-run-action@v3 with: image: k2fsa/icefall:${{ matrix.image }} + shell: bash run: | uname -a cat /etc/*release nvcc --version + # For torch1.9.0-cuda10.2 + export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH + + # For torch1.12.1-cuda11.3 + export LD_LIBRARY_PATH=/usr/local/cuda-11.3/compat:$LD_LIBRARY_PATH + + # For torch2.0.0-cuda11.7 + export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH + + which nvcc cuda_dir=$(dirname $(which nvcc)) echo "cuda_dir: $cuda_dir" @@ -40,20 +51,26 @@ jobs: find / -name libcuda.so* 2>/dev/null - pushd /opt/conda/lib/stubs && ln -s libcuda.so libcuda.so.1 && popd + # for torch1.13.0-cuda11.6 + if [ -e /opt/conda/lib/stubs/libcuda.so ]; then + cd /opt/conda/lib/stubs && ln -s libcuda.so libcuda.so.1 && cd - + export LD_LIBRARY_PATH=/opt/conda/lib/stubs:$LD_LIBRARY_PATH + fi - export LD_LIBRARY_PATH=/opt/conda/lib/stubs:$LD_LIBRARY_PATH - echo "LD_LIBRARY_PATH $LD_LIBRARY_PATH" + find / -name libcuda.so* 2>/dev/null + echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" python3 --version which python3 + python3 -m pip list + echo "----------torch----------" python3 -m torch.utils.collect_env echo "----------k2----------" python3 -c "import k2; print(k2.__file__)" - python3 -c "import k2; print(k2.__version__)" + python3 -c "import k2; print(k2.__dev_version__)" python3 -m k2.version echo "----------lhotse----------" @@ -64,3 +81,12 @@ jobs: python3 -c "import kaldifeat; print(kaldifeat.__file__)" python3 -c "import kaldifeat; print(kaldifeat.__version__)" + echo "Test yesno recipe" + + cd egs/yesno/ASR + + ./prepare.sh + + ./tdnn/train.py + + ./tdnn/decode.py diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index c5e252abb..5338bdca7 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime +FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel ENV LC_ALL C.UTF-8 @@ -51,7 +51,15 @@ RUN pip install --no-cache-dir \ sentencepiece>=0.1.96 \ tensorboard \ typeguard \ - dill + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index bcbf8b599..4d2f96c8e 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -51,7 +51,15 @@ RUN pip install --no-cache-dir \ sentencepiece>=0.1.96 \ tensorboard \ typeguard \ - dill + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ diff --git a/docker/torch1.9.0-cuda10.2.dockerfile 
b/docker/torch1.9.0-cuda10.2.dockerfile index 7553fcf86..a7cef6dc8 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime +FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel ENV LC_ALL C.UTF-8 @@ -13,6 +13,13 @@ LABEL k2_version=${K2_VERSION} LABEL kaldifeat_version=${KALDIFEAT_VERSION} LABEL github_repo="https://github.com/k2-fsa/icefall" +# see https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/ + +RUN rm /etc/apt/sources.list.d/cuda.list && \ + rm /etc/apt/sources.list.d/nvidia-ml.list && \ + apt-key del 7fa2af80 + + RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ @@ -37,8 +44,15 @@ RUN apt-get update && \ zlib1g-dev \ && rm -rf /var/lib/apt/lists/* +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \ + dpkg -i cuda-keyring_1.0-1_all.deb && \ + rm -v cuda-keyring_1.0-1_all.deb && \ + apt-get update && \ + rm -rf /var/lib/apt/lists/* + # Install dependencies -RUN pip install --no-cache-dir \ +RUN pip uninstall -y tqdm && \ + pip install -U --no-cache-dir \ torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ @@ -51,7 +65,17 @@ RUN pip install --no-cache-dir \ sentencepiece>=0.1.96 \ tensorboard \ typeguard \ - dill + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz \ + tqdm>=4.63.0 + RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index c11c0bd67..d91fbc24f 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime +FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel ENV LC_ALL C.UTF-8 @@ -51,7 +51,15 @@ RUN pip install --no-cache-dir \ sentencepiece>=0.1.96 \ tensorboard \ typeguard \ - dill + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ From bcabaf896c0eadef1ed8d86907847c367e4bd14f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 1 Aug 2023 12:28:34 +0800 Subject: [PATCH 013/113] Add doc describing how to run icefall within a docker container (#1194) --- docs/source/docker/img/docker-hub.png | Bin 0 -> 364778 bytes docs/source/docker/index.rst | 17 +++ docs/source/docker/intro.rst | 171 ++++++++++++++++++++++++++ docs/source/index.rst | 4 +- docs/source/installation/index.rst | 5 + 5 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 docs/source/docker/img/docker-hub.png create mode 100644 docs/source/docker/index.rst create mode 100644 docs/source/docker/intro.rst diff --git a/docs/source/docker/img/docker-hub.png b/docs/source/docker/img/docker-hub.png new file mode 100644 index 0000000000000000000000000000000000000000..a9e7715b0b41d49cf6a2717d2f3f42c193269134 GIT binary patch literal 364778 zcmbSy1#n(Hu4p)4l7_iqW@cu_hMAcgW@c`~%*@P5lQhuK7i^fBndzmwd-v{rZ)fhm z@64RBk8R18EK9QN6QL+C0T25b76b$YUP@9_2?PXC2?7EI06=||bk+RO2LXXCw-gak zloAmkRCKgAv$Qb<0g;SIN`Y2PSj6x@e}6wDZ46c{3|@mcgosBGhRY?y6D1=B6^3C# 
zq#sYUgNlS4Y8JE2(WJrOCno*!My$%4#z!Y3T0`PH?mm3L%@JVHcb-v80D86-@h5wH z@TT$PxwHD;uDYJmaM=>k@-pnWcBy)!g)=N4S(8`2Y(g|Fs9mcGdxf<-AF#*-tk^U` zWSb{nkq&I*g*v(P0-1jJCy-yfQk-$T0XIHzI(<%pD?w?NOL4zlFlbTY3cwygoSo_+j!@R|uxYCl|lXA9G;! z==fS?^}Vjai}7M7T4!QdcDajdmwLG0*hCos%VSOJ_=n{^r2&>l8})m6QT=CTA-KD439++Xy-RFTO`#}c_=-sof*Ou`vz3^=BOa0^G%o2}ZH7o7+vl2niaXUxZ zNe+U@y3KZTATTb3o$0SQDwC&){S3VFEyN&x4C4OEITm}@5g{enndLb7$9L`lp8z<} zYwj5GL`X&A8XC2~lpn_FQ2nK+w%rYP%BSHz4Oe7XCRgkxBv(-twOxZ+U3wvb9f18Q zaCulE-?icwz++Xjb6Hp4=P8C& zzRLD!?*I7>>@BVFC({hjI1kuQn7G8$bqTnDEslAuUmRzwD`Z+cI;)N5lI%aUhRkGKuClg8dej~)Kf->a$Y23SYc z`Ivq{MRhJEO{Y2C?8z$(0XJ>{kJZ6c+?0TD8^&P71~{`cG~72vz5ZJ=^?yNm)=lRa zMg_lNByCWeZI=TK6@wVVk4Tsi+ai;Kk`@l#O0q{>BOXAfFmh8r)F_i4jL})%xtAY3 zE^O)pe8uwLX)M5+8E_SG-VTk*4ixJGUz3q}BS&;a^-3u=6ag{(HUwb5w%m3nsPQuZ zAz{=tBR9?hdT`CPg#8OT0y*{P|N4Wgc<|MxDwevIl?!McaKj@RrdJAsCOat z$Q9QjT)sCnM5-9rRpzCKP?bS{AteA!Izkngiq54TydZ;-F7wh?!$0`d6M^ZY&{E7mYaRpqHttT4y{AUSh@LWLlP)>&EQi54JE|zj<&8I2qO7Yt*DQuXdw3hCBWk_s4KQ#=U_Q zr-SM_H!Z+tHa$c~dn$9+L^l;W6ugm=jBWe-&#zwm`+x0^@u2_ zYq06zEqWd>MvHGkh4z3kLcG+Q2lJwTX|?Y}{Q)^APW}QP9}D$gKsUJ%H#hM_vCn(C z5L!^Pl&9)?#!ftd9#RDwEaK7+~xlx7r%M%gl5g9y&K+JtLhwQ zy1pg+bM-M`rJ%reBv)0$L^1RC_FTkx?)4!4e@uCBElZoXpXW77Q+!%Dsf zkZ!GSw)1lTPirCf1GPWdOTZ%Yx=MP@*l(|`X+gE3Y;QzsGYB|Ri z{+~YrkmP^1>yIAl|04Q7*7YBp90wEs!O8!xIGHpUAPyQig1aQ=lkvx~VCC@3B)o%X z?PyBHQWtW*ZmSDaIpMubq4+H*@(RLsO5;ck2SAb+%2EwP`(d_7RJJ}4^;p@rg_$1W z@ojvOLiZ+K&*1T4BB|njV;`b5q`nRLk;n6gBV>1m@XnDwW7R_0R_Bd9$K5fsY{pu4 zK)!h#DyN{zf0jn(38d|HScNtmjW!grZs(gG9<5kGT~oc6 zzF{<^SMMqamtDMaTg5Dx+v6r)|7f{^vW#EIjN}ntZnQqb<@A7&?7fVvb#MlfrVGhd zC6T@u5KNoP`3ED$(6&$ym@Q*h;^>f5MLf=~@ROI4>?;j`fZbLI<%izxQw_TyF;=WE zNM(h&G~WBSto%0Z$&-+!hh0ZeU0Zte-`x&1`-AfCzh>*xi2ptX(;{% zO-2*J9^>w~yCY8Yz%QM)Q1TmvQ;*JQCY!A|!VgCCL|+9Obe2A)U2{!4_}i|-$cumo z>y((RJHPEcb&6NN?vNlNa&P2Srt}Ly9RcDke{!wq@6?4+{}TuGSNmhyYhS${U!X1c z*k@2}%?5k@eBi<1H%-61-Ct>xe$#bUBRg%G*{ckPs<=1A87M7vCf}YKl{;9`_1!5;180aDBmyCCyIH;nd=INP&MD)5ICS=|z9X8{ z()}tz33QPl@$=VwuWC!}1PLkVWckH&PC$N7d7~#-uH0}X;0%xcBPZCM@biyebxEw0 zy+J$|qm!Ui&qU;(3`f@QFMXj4e_xM7ueoE4Cie%wY|W}&*3+aZm~ER#CF2?>?i(|_ z)M}26R~~IUTvFhR%6O8oD|b_g&wrcd zGWz6NEwLUmFHhZ5DC_h}e#3GCA<1=>(cCg9|CeEyT_KdG#LS>Y^8+(imKU^@8dt{Gna1w!SUeF3-^C^ z%IAHWlU@Z7(-8VvOq_%Mn$p_|Rnz@Gd_uSH%Ym{c2LJW9dCo+MhCOvSOOZb*Kg_e? zVEl+E_mS;lvDnZ}?=8uhk9SUM`1VDoKKo%S>oICyg`axMz;nNwk9V)_I2B;O-`SRR z_0vg6hEP%fP48LuqREgPf;?p-5?Ztz@&NUU%d0<=ry^aNZNA~``#oR>^~KX$@_Q1k zpIv$ATK_zaw&3q*gZPc!i#stnGFjgHzN7)D$i(5C1XMGn_175Vhdp2m%3n8gPV_jg zN5HsY>w#=S9qpW|?8j$wH(gAwo_a(Sf$gO;?;Js1$O8ywjP9 ztyD6zb3=koa$41hP5yR;{NY|V9JT@w-vF5bzS}$?YX>)P2`&$Zots5GU>|q|Gw244 z&OFY<B8nzn~O$xg@jXxX#~M*RgF0r*kNWHcl0{X*e9l|rP3dYI`* zcN;eVG=35?mnZp}qRgp%X8xW>$=SLUOA(OVW>h3IqZzBLU18;jP1!yBZdr1&Gm<*g zDtB+4C6M=U>P#4XHyEC~`Q$IVx(WNMpWm>^76?6(P3#s+r5?_+r5%jT=gr{ZR5Cn& z3ZZ}vfu8gu<(n4rE1Au`1%4fXknPyl`4MH{&sUrf`$G!*>R6mqa5)wytJ-fI*dFcz z&!Uo#2wYyXQ5ZVY-R~6gT<})6e8ZA60q~54k6V3KsjHKeLP>06x5UJ23vqufiyw`D z7e+~pyXg)bsBRyMK^ubfJ?%6KeDo9DfIWMEj-EFSs4H7W>k=O$4%Q3n?Nw;Jl!~~$ z;p_r>&x+haPh+)#vSUntmM_hqQrxEJOAeL~w_i!)>jnzsx6bK@{1tXgQg2xujTNVW zmK=AMI9Tual4~amyXziw>h|*x-cx2m5fG zDcn!w2%$MdP6VEqhTeu!iU(||2zf&i@B@ZUv;y#BU9>aYviPyce39L4gJ+8X!@%p# zsl7C<0;+bg4py=ix?c zcRL&DMr_rfp)ILhILn8y%$Sa52nu z=3~f=gfW)4+Xv7k-4RuGjY~%Zrbl{3EAG!9L?mQp@YDk_3sH9J1Vir9>;8sKaXvKB zcV1dAAyyN)gD?79E+9@6>EnT=)|%&59NrREtXZRfPfPpogP7qo{0Q$n@CEq$mBCr-!kn%7n-N|! 
zGNm|}4~arDrB>9yFULAQliiC14#}VUrjHkU<%gyH(n7@%wh2A#gxw<4zTA(c%q)%=AXsgy#UIzKBEh-ITj#QI< z+E33K>?I#qMSbLEoL{`zyeoaU-fY!L4R`d-Ok|w{;*FgA?&YEeO5=>kl8ul6@~BxS znVOjzqi}nmE-J@&pu5qhul!?fL&;dY2*rLwgt{q$4hBL`bSX4`N#+tcfUu0|TPb_* zEYG+%1LM~jR()P^!5J30I(!hL`2c|{8$!jU90R`FUSYCfT?i& ze0OqYpL*0lg{~?lpE;c+cCiRvz~$0*_3g_=C4;rfQHP_0rQ>q4R4TqY)GzuaT>)w? z5#f`p*_yTKu0PP=Oj#pR9Ht63h=Z9CQh8FIN|G%N(-{ub&dEc-I7mGoWujAUGu3`c zikaxnR6m_*$FX$kuNh5j#@ZMfKN-DcBE2(gm1d*EC7@k3Gim86sWH}EW$7pI-*mCx z5{8>*`o&@wquZ()#VWHVYPTu{0-xf8iMwIEoov^*{kD4t7(QB~R-Yg4?!K3(d)|_F zZTFj`Db4x24OQ|_GN;p- zo$Vbmx93>N`5GIfVxx+#Md|;;-djdxwY6=-ij;(abfbuXpn$ZrbW5Ys-3`($oziuI zbV+w3rNBia-O}9+-@M%Sb3gmp?&rO~U*GuNG4>ezK#+BGQL6y~fR2AX;M?&s0P9Hz2GO&*)?=QjGW1;n5u+`LE#Q6NhWSiai% z(7ZWw+H|_j(EJuD{YFU8govF+i>gl?D=lJ$ki`{hsxdsRmn!Ln{8UUCi=IV1sB_my zHS|t|0IfIQ;;We0DKe1^>ybDCV?`c9z||Uy8*4hLq#Pq{#lsqyl7V{0l&Sa=wx&b3 z!x39jLFKo-b-mks4z;{Mzks@~+rqcFg0*e-&s{|nWnfL0?>1fb?MaRe#Cz~wI)RF| zS5;#Y!V;qO!vQ6y%rNZp`QpuL8~KeD*Zbqf^|vSQ7&D|Ei)AMyUnUXRUXlpl%lY~A zUCV{*PLjE7$ssg|-Z65$UBK|*(52Gf^|*fNK6g3GuKVud;b|7_8RuvMOUVAV%TP1} zLLv8-Hbwy)y*%WE(fie7f})fmCalw!)^kK--VU>8WL^63D(I5d7R1pEp>&$XiU?;I zJ*l>8;V7ygFKrr{%Rxp9Lqea-bFGAaGY>yB=Jut4;so5OrJ1YYBXVd=NYFj2R%yN% z73`siC0~vo0tpWCh{UnwfYt5wscdkeUq(^rI!EB40Q`2}Du*=&pU)KjUyCVzLm0x| zSZqhDMFmMWz4_+2*NSAQGM@^Z$Y@lKaqa#S8x}OtUAX-WJ0kKngC#dEEvjtMd$;9a zJB-NL1yRnno|~MKQ?5~H@_>+Xxd~~lZApspM*$9xm^?Gd&h%2Nq_3l zB5{O?d;EN5LN7|%dC=%&Q}ed&!^L`zimYB&O@(=zH?4vB2u^7H6R_$QY5g}%-Da_U zL1?Oh7sF_5frsMBhv^T5u0b)WelbOMiOC~bp};cUS%EWf3ew564chFjPt6;pRNG)c zCUXj*TP5(wJcBxP8mEyu%TyYkOUOztSrh;f6|+gRN7?QfACBPhcwSTt=Ze&VqXgZQ zGGGDzD0D>srER*HL)}=js4`(SBpBv{9MuO(Xa_omHKcFVqDo`(00%tJs~h z0-M)rsYo%fxUrWeS4WY_0wDKwn$|uqfzo3;#k!MFso1lN#I`p?_gQyV$|MAQoK%Kr zV07|jRQr|R;IMS`U>c-Ao>B7X9sYV`xT_wk<)ucYtMHp9LdDTH{0t`)} zma5j_5gTG=f`<^6=adMA?MFsJ;q#%ne9hO}mIaqR92o5tXP2;yI{FKr89BDKEa9PW z%JnxDB4sEVu?`=RJP)rLjN=Sx%!PU5yU z$F;7jt{%=uM?+F@%85i#ZjcK9m@_Z#m>&G5y^n8E!^ecOIwu8+!vcy<6mA@JRq5&7 z<7JCU1J5UcSwXnMI2HrmZic7jRa|K|-h+0t2vf?pNDzoFNnD`ztW7+vi^pi@U}!i@ z*)8mAGlv)~$<*(^f=EpE!tT#(9;+Bt8w;6jv^$4LZK2V6u6)?{~CYt=tU(RZR}!xqLE+;ly`w; z!R9drC>1U_fpz2pwrzbb%(kcbDsQr8B@mBk@Wuk|WXN%lIltGk{UAIYmFttk_sfzy zeN)>N=yImL(TV+6oVSh|7^Kf#I#_vLKu(~L!C#N%UNp+#jwam`L_VH$_mS~iEZ+%1B{JNys>x! zVhc@BwyEi=z0O;Owq;NYaZ9uDthU$Ye3*-e5`WR4ycTEK3`Z zhW*<71|{1|M84n>mVDZfWeIs;5*=K2aZ=qpQ%t$-8FSdjRm^8TUO@Y}g#!F~Qlc8w zbweR_V+x(?#3I$@C})I%?3h7432|9WHcHe=vi`2KV#ai?XG~2~yA#z7qWXIsuEepp zl}f5?)_e5nMd%WP9Au3J?Ki^3d{7*oi_E1WCv&qt>{Ni;RiCV0f{gf;u1#H-2nMT9 z(Pt)GV0te*S<^+*fb&?V;pI&^tybGG2$AOB^YkjDf~vhTW0?GuO%O4_WKVH6R-1eBR9Sz6dNH{&%faau99qjx0zXQDcj~bfU+GgWem7PqQvWu#uo`Q`wnIJZ=7;I7T>X5?1Oz?h?j+9{V zE?}vSv}DK?K{;90FQr(AOFlyVsEUI|cwSO%Y>@$uIcoyW+A<1`R)HsYK6C90^~P+8 zOPK1ALh8jx{*BNZC9Zh7FSI+HI<;7yBTdZ`UD{IQxKGLf=lPIFW{4-=ASt z16s`MS}hWX;`?Vz2LMlp$oXWL9y@;Gawi}?!dI_h_m~c$9vTkM^MT-meiWq?Or-Y9 zV-Im?=}x15OX;!kk>WsP$zH;r#74}cc94;&u#{dSD&ohSYY-2}V=It9#ao*2F5~xH zEHijaXhwU+@OHG5KSftf(DT|)Or?nuXvEI!gj}D_d90-;@;s~zk2!n9$AfSk=W!m_ zSl-X;vSuR{{~U&3QTrHa8Q%!^i?7>D%DE4swx{LPs_3;yOy1mkz@c^0H>$gcD^Pm9 z&z{@me-!P;QXxtSVM*X;)>;PeYCQdLM6G<=Y00KaH=clD9(@LNfI*^Vi-av<8!vk? 
z1JqQqcMh$D;OJd+MXyTJLCk2CXo`6@@szCsQhiFn||2i0$&~v|lS}EN;X_@$lmEtc$PdSsZe#80`+b zf+a1M6)CNElstvaQPZf*dXwD<(@X<{^l7FsC#{zLoAB4h5Q7^B2D~T|Y)hA^sf**v ziD6k+nMU0TP2GeaDSm@oA|vM|W-u3(be9f`CeB$t#N%E?zX#O5Yc;OvDC^oS7&1J0 z*M0lkjhr5n=UgblrA{R4@I6|Ff384lBKx$J;O((P#)d4sc=mFok9=o}>F|#iiY?u_ zrd=gVPpVpZT%X++{4J?G1xckHqZSpB^l*4I(-jQI%XoSV+6W0ujhvPj(`c#_02aPH zEC$qzQEab$MXgJQ<wPSa0S&{8Sv1E9)8ahUGoOUDgXB1!p@*OOH&W^B)ZV{wkcSS2d4Sqf!8Nx#blN4AJd zFUPxI;c=DamkKgN?$M-fI%h_LSP-6sAZsa)MtzX?f?9$cZCIUeBSla%J;Sw&%>t$e z_mZ8Z14NP>-T_c|bJ8^T|Xyk&(sq$u=#sKLYakW)DD>8#ZECu?lekP|aEBbZx5fUgsLN zV!H1`RqQS-s8T-$_OZ5sMp4AH4aYsvBye##OUxd3$16um2;;`nYzdQ;Xfu zLb=AWzce!gq?D9;XF-9M;I~diwjf|HA>C7k!Ra^Q(fb)IAhW+9tmK63T!B`K0z#P( zqvVw^{>VDg00G-QW%+x8wih-1>v_%7o-u1D9@BM`!%4E!*&-j~4Hb3Yb1xI1zu(pu zW)!LyU|KovU~nH~$1NGki9ZuGA#nRq4q;`JZGb$6ut@7;ht}^`>_6bz8mI@Y1I^>Y zhr^5Fy#^+1D8d76C9t`f1K}n6hF27Dul42ynSYS*-P2m2=8XHe%>o*elo-_Ml>EZIgvugS-GK5=X^wA0 zR7T=%=zdU`+`2dF3xc}7WLS`6Nlgns6oa1;-sQ4>FlAEk>K7!@w?0qf8)JoHWT=PSX{ z&a)<(_)m`;%xU|sziachtvOL{D8RdlIAxL90MPNv+hQ*9#@$SnOW_EupP^$3LM2~e z(G(cc-JME$x#A}wX(H`85%^<9wZmxhLwq4_JeU;_50)rY|HIO=Z*!pILz1azIDhQmV zGMOnH^%D$9Fp-$q0FXp;1|phMfB`;Rd-j|U@+mA0Z`}d*mgUo8-EA^Hi&WE<9FEuw z?VaY^%VtBV0Fk|tFbxustTk5J7;&y|=1rh@atoX#*Qp(RFW6v_<76JVZLFc0I2Iev zX6z|k?jK`NB05s3dJ4K4q`|}qL;Dw3t9%}vuIsN}Z4nVMukIMq>jKXMnxfy`&qGY# zgwNLKG-psJ0k>10K+Vf_S}ok^VTmX!I$-E*KAJYRC&VvVXBq0Zv4&3Dy$hi1qZZt( z%OT|H0Nah4W`}SiuxpuQ%8A8^A?S&~Oz=t`g8rXda5H5%x6S-bR;2arrxl*f!t$ZD z4WfrjwqctFEfVl|ge_m>Rp+QXOl08~LlU0M$lq{nWJbwsklGChH2x%-!aEh)o)vid z7UE=q9e{)?Kf%7MopdUR6~J1w({;pmy)MhF+6BT_FR-1gi8cw_+u^FL&n{ek*K$=;5r{S1M@e#QnNB26oFGK=;o zLsDKxli|q9rigNV$(vP{Om0!qYNQACnp)uy>1hH6_TkpZ3L?U(c zVSBrARu+TqnSRyI8j%rL1J7;MaeE*aAl*EuwZ>9ZA5!Txl;t~9 zWLTrB$YBh>^-g%p8W>Idks;p@M6(jCVFKeX-xL~8aQ+HW`t5WZk6G{OO%@Am6ub8N z*W?62LUrnwlDT)Lc?_&m5=nTFw(B;F8b!Gmv6wg7OHx8qo`2LWhlp5dq0l=IC(W|Z zZ+_bc?uYgeh7_4wVWdiYjrcb+?O)#vxB)gjDY~jLn{Q`$ilyrkvFMK#U^naj=c+Px z%VYvHQQ06C87#>cHqQibSa2my@fMbD2T+FPx?_edcWE}@rKj+M$k`X4Gm$$EDz7Ua zqK>Jaq?286ZeWaw-03AH+_s^c%UwgYq zuJiW_>_wMOOIV(1E$3#0WQkY`9Spouxlgs8gpvC=@;?x9w%BmL=F*H{`o~ z+YBLbKU-j45#!-%FL*X7o5*hH3VL=b&nbzbrg)x0*|oBq+z9Q=1CRemo~2AR?ora{ zsRjWF0EOhu#C226%Xf2BkP7nH#Zk|74$~8Pyu+N$H6T0D8&#iL8e$!qi{+T)yg+d5 zoU`Ud0UdLv(iuJ+QIHKCeIV6$eedlG_sgPSx#IK%F7FqMwnxcqG{U>@n0Pdrqq>Vp zJ>w6uUJ>4%u&t$*J_iG|M`3LAFw)QlIkuc-BJCHC5vE9Q37ynL(0WYrV?x{4SBYVe_TC3t$SpwzSNsH;4>D2Xdy}oU$ms6YcS4 z7m+;Ia>De`p69?4v}6TcPA9beN%Nf^*N7z2GJ;9oql>)=f9}keA&5}N zP$6a*?s$C=I0mUu_&Em>`isK~7Y(C7uqTAauE3VJfP91~Kr}CgALL%bGMWC{skxh~ zl*S8p&^uN!|~bj+%DecHc|8wA>?O+`Aww&JP-s7rv?aAq+8)7#q|EeFgRpe7b~C z{weL^ZxVNduoPLgD*0*mXu6sE$e|-gB%={K({qYRx2ulAdLsp}VIm@tw9b;F4D##T zMC!1G=$wfHF=oxfCcAMS<6iQvKR{I0Gf^!mc;PCAWCgODJ%5!ctko6p;${9Vd+w}C6rCub2Llrgd&@Apj!Qga9QEI#L2TuFARKXxQB1V_ zWNY4A54bIgn*l1e4v6K z@nvnYbF6ip55>oibd5vEe~jay>q4&v6+~Igx(Y}aMHkLKvmLPCrmK96PDEd1>Q58d zXDMDVhXYQN#iKB+=oZ!>Qbw;Dyz{V;rah`PPc0p7ibqjsX$R2b-$JpE528}9^o}!a zd+}xOi?-QylefXT2pbXEzj9JikE{J_#3dt~a23PMb@HTSN{S zh^RBzV~~Z!{03`=WUwO203xP%IDNhIVK>EYNPLtYs~59ziTg02e`_^-nI2#f@LgFK z*j{pjj5z3VZ5y&&TFVDL47;g58D2+#GpS=g8%i2rt2@$#V}*P92#ROKwq;S;f!Ao% zUJp2k1%w$E?WQ^~V$RWKOUY7+Qz0nl8MfmDqxDI~miK$_wECHzRfZ-+p9p{p>*&VyaSa7_TZa z(o+#;{t)Prf)heIR)uTbBi|)VB|AzCD6mC!i6fx zqb*Cr`7t8{T)DiJ%JgZcJ=17$O?Jl>bfl20t*?^FVOfY$)Fo+ZGwirO$3I5Z{SfJ=9+Sx#py)j?ABpuWzx%=-{_HhUw-7fae$l4&uj3o(M1*bLu=dIkb({~4Z^HDZNx zC`$R7CAn+$=oBC}PRc&~KT}Sk;gmlO=W3699MKwc0BkGXvFb4WcAJbyhUdj^{ZI)@ z9!_E1h2M{MdZ$uLEwA{7k?Ctc9=ovk-V`+>(ccDVK)(P$!{jSrK(m@uRP2 z$ZrfjS_$_%lB3;`T$)UZC2d2tFq=K>Ja!&1{RmnsN4+>+vc`l^myZP1s}PsI4Rxw2 
z<3Ic}JtbV1!>rTlO&w(mcac8SbPB|o8n4rzk$m!f);!OE{>hB1c`x715?^Pod(fR2 z9LC74$6{H8Psmf70WtpE>bbQWFjhkHyHr2-IO73n0awj|SsWio z#wC9vP>dL`+(Cx`Tm!pOR#2TQ60MnSIZ~%?kX94p7)`FzevOXnBoLYq# z#k-Ex&nh*F(8MgfMk;JK0L?+_@uXwws^vxy&qLB0fTlnsAJJ0)iRbhg5rl4s-f8xE z@G2L^4(l)*LiSo4q;WnmK?)XLQe$~`yJg5n9C>*nKDa5ug2L{jp6c?F{v+_H+x`RnkdFK{>u5yYkM zGxhmEZH?(Eot1($4E(a}NJ;I{LmN0R6Mm4k(m*|ZIAqZzV3h6SXKfCi0`QMI3L5;f zvR6R@uj*%e%lQ3!SB^?m2W$1ru%HVqbS$@{py;W%wI)%Q{42k78Bt7$WQKI=N&#Ag z!mC5RTie~M!~R`iEwSk?k=e6g=Tks2urVt623#n(;&>59HYGu(n?^9j_pyc**3s+! zv~A0&kVj-21JMtIT>4?hoO$Q@SH2=^>*s(5FX^Yh4XiZOrEsNDHrkvy6d33TFwiF2 z9p>NC;Uu^Zf#lacSZU>ll zP>7bb!{gei(UoJHgnqDQK8W#ml0!_x(HB&fWAnw+-z62Euwv3B(^L-OKhJcKW4prF zzayz++0&WDX(=|L?=06OlS@PFA9SF^pmYOc6Vaq@6!D5SJ|iykd%?aTI$@GiAyVqoZ#pyyt=M1r zB}`ar_Yw^+^H@@c9w*vc-9`=sfQ9_uStgW2h=LB z<>Gqy>o}k9gz*-MnoFHaByBV`pjFUzleGn%Pp5{yb05p;xHeY{f84=Yi`JfW_(JqL za8pO^eLBfgaX_zHnsc_6dEtXsu(=%cY|aj2h0Oia5G%%C4MWQqM%C@1JM~`&)1U>D z*=XAq)f*e1J*_D$C&;qQ&6ia;-#$kEJd-NGY zsrA`@cc5cT3z(su6+oyL`uBLI);e;Hzve4_{XZ@2U;NfuP(qrt{u*(uMWE%o3+IE- z$9H?Bv)g$3!W|^IXvqt(5t04UaJNDfJwgtj$Uimtrt{!-7Iqaa=qUfnS}SA(mof<$HzkEjY+R)r-_AcJiYI5=VtLV1@-3#!@T=JzM*y!r=uifddzO($>$Ny__fBHLmR&Yc1fq%%hnBaDu{MIP= zjrt3}0vpSY?fEt0Fb5YM{#W4UuZsrC5+g8*Jyq;4A%i?6lBxeiGUHD-h8bTB9`(By zzee6&QQz?X$v%F+rC&d3k%vvQ#Hs$;-7IBmGlsv~=D(l(zrWFHf<5TkUl025Isf_p za_K)WQBoQ<_pS45+CA=FM``1q@A|JCgFmkK|NVnWMX~Z%xq0lyl5FF}e6#JP@IveK zzfbS)b^v~oRsgnQ;}aD2TvMOLgDgXvMoXCV?^pEio%JgEB~8HgOH3kP3naI`pFMHy z|GmNg%Qtn;fDOHv__gnO-uG3-ANs*5kNo?p|I0UP`LN@!_iKha2j!_2l{f*4svOdP zcxeAuEUZOc=6u4q5yvdjPZ&(Wdo2ALfG3vra89Z29oH23y4kNF8@zA)D`ZnjO63om zZ?1U|a}S5uXF5%_tvFQVNFVU9-#*!7{LRt-{tVJkVb5B=7xt{@K6dMo|A#GvT_Fm2 zdKr{ax&|0@%r5fcUc}qlX1%%A&~@Xs;VYkPbfU3jzUR6Fu2AFk!;%K$iyCc<8|xnS zls})HzdtGixG*Oq{)cXg#K$u7zyEK4_A6f2uygzQm;EfEbld%B@4&C}%LgU=lIEbp z?RNj$L-0ShsuLCZR%$w5u@}@Q3FaLQc|G{AXc_>u> zoRr&tPRc*->;JR^|HPDkVhYd?`~PX6_b5*TT7xSv)d@Napb(foW-7Cv@SoGMk|Wq8 zi~cXIC_beuWaZkcAUcHnCz!u+gnqkLnk_7^eEdfZVe%tJMME7XoTRf|`GbDPq1rwO z{DfBe<$o8mf6nRtM;2rE8dhb6{K|0@%r+aVpGg{73% zzfwxs%8m43hA+GpjiUy}(f|2x`O_3&t62P%9jBVaz$&5Nm@R+Gt^W9MJ0;i+m;G_W z2-qc`o}m55eDQCJr~iC=14^)^eg9)=xKf}+!T6u|^Ka^{|CnU+tHA7lDhZY`^C>aj zr4#?1r19JH`3nl7FoQit)+s^*oW*)bjN3@^Q4g4<^^yY4DL$EukYro5*EXh}sI>oC584&Z9P z)q8Jh`!`r@2)-kbR8VLDs!u!puq#xLgLmCs%(+nv=z9r1;LPMWoA!ab|xQ3rp9WbFR zuX?+Uyc|5hF&%C(7I|1x4&MiKvKDcR9fhs8O1to?fsA0`=uXq+P6v!SYkHN!w0K|E z@pYv-)LuA31%3=o(zLKtq!eiI6Gokas&EmAuO%_{{U&A5Qh}Nbr1?7lF`6$z8>TeQ zHs^Y%v-zd^g~POU*P^kk%ezvbY>tuqfhioPq%2;tt7>Kt)J!*gE{1T`9Y;5}hoaBk ziJ=bVr%MM%pOh^pohipYd?s0@e=G4sDAw#_0Oqsz(Pcsfib3+U(aM8GZgW>F9=u&{ zq0F}ZjeXZ6!-uX#W1Bm@$v|Tw;gs`qn8%jaR{gDX{C3ox`w{$%cTtodt9?b0`T8Sn z;j@--p83>={I|CFx%$n;@=4Ve_$=aC!eMlo-KGl?DesgN1)dvxzwbbUKnFAdtM|!D zwOZkp_VuYdHvq+O zOyP@Sp8N*TZc;)(!P)z{>X#Cy=|0hWC^QySZ732(l8~z!OOm5SC0SA?_mJZ5-+VE` z%VQNHm!aT3%f+G;$=w*m2wxq+Dg5@02^tDdqxwBGqtEo|-G|>EfPVyryjE=$AHR4> zXj7o2|NVe(w8ql?4VS|X9uA>6Dn7Tex}$}g(nT~j^U&;Pqn}zR$|u`XhC0!_H)l&G zr#nBp;AWK}9yd!AN!>@^5Xc|>kOc`71!*ZN$!xhhCSM}v8H$GC!?%|w+oPAe&8hz{mU`ul z>zLlO6voxd#I16@T|GC_;B<)7t*yBguN%gsJ+eP4NAX?gSh@ThkO9lStnCToF`;wpPV9F@7DeFlAUxObk& z-ua5+YvsLw9wr=&m#28k+sf5b+C{cmTN+E|dUL)CaFiFHX4E1b)&?HE@r98{^NZB$ z^9`Vhc|h?-SilnjzZh5w-xD<&Rn8~lPV#^AyeYiYUbir)8js{gCce~{7#VF7+{aZ> z`|N(cjDM;M-t9Tko4~BpA_;Z{Ib0Fn zu8}~g8};kWMOkVRJYD_imQouZY4H#9ar!v&SHc8ok+MS}ATf1t>xEEh;0Cy~RL$@nTX*JcuUGe5~H(1!4$WpKA<${v?eB23U@@O$!AIAaZS!_~S!F)OcHBkok&7+EApj#qe^(5AYXJP2$>9og zWx0ym>7bq7|K3|FRe^CQ-_QQlgxsd6x~4x0<$!OK1MgUt6D6@0iqFR3C{`_f-TGdp zPJ3X1ha?2No~8gCM4WuS+y>LJ7stS=llEU-tYzp{RqJ*=sNS7VELvC==3eDH{~5J( 
z;&FR(y|OIU+x1vDDAT?Uyej=!2)X0Z$M1}OoSyEwyjg8-ZGPx0F4iDy)g981wxV>J zPL%0qV>sq9Z6%|%`P#!)J#R(D-zuT^^XrQU|1S*KG}Xw;4@l15yD{n%zAJanP;*|Z zc;lt3IM!#5I|;eCR%5m8M}48zVg3QfH_FzM7R65<>fv;>MH2 z^UoWcm=moSrBfF5(&%K|2$PVVM?!zHqL5qbDJyY+OS ztBbCHH_K;Lu2iP?J#*s%bpgqx+F}NWJAdA>Z4e-VQW$=!S*k+@!*{)E{R&Rui zqJ1F6tDGSk9@fnlO^8!SB5}dZ{x^<2f?tL@B1j8?K3gUwpNQRJOQ%(b4Kmj8Zon@W zzF!xJ``c6QL0G2pPPO=jvU z)jMYVYLH0U=f4>A#xtQXtSx+eHJ+2|p1-%ntt9By@RQx;#EcRRD||6aB1YbPvVw`h zBt3i|t-nIkd7X+??;GZl{oC%>LJB!Dq{r>pvhh}@pDc4Qzt7j!j6biQPA7Amzq=H8 zl@ZSSG99zzBRg4zJ;=ZEfs}D4jz1=fKDA7;5%<{@kRMpl-#A$ANwWP~E*p^o=07*U zsbKRr`B9jMd2-!`!KaAK5bhel$H)MpZ~mSud9praF@dogxbtueWT%Gl0M3zeL)mL@ zD_1O>l}>OX%p3)MsSp=>D#tWTHR919Ci9Pb8xHA!O-M~(jvoU8Rz>UVLM)6Sq#7k4 zXqQifA7eE1w$*d3>+pF{}We*?(R<0alvCCB>i%XqEbYsF1U&CkzI zy8_ib;*Ua};VK5(_BkY}W=@eaFK|^_-ri-$AU-xxQIaOQx)kiPWhCCwtJFf^6WPBv zaYEts=H=PvAad1xw2jTlO4EY#!&R4{g_UQK_z$k)eGiDu142nSU#>_8MQamD-Qla+ zrzGc#)vj^61_Ka2s5$Y$OkdH>Em)=l7CHwkYwkL)efGba@V86#ju8w>rKbM7c>HYv zULLpbP8U#yG;K%lHu&m$esgtH_mx-rQVJ4Tj2C%*+Qli1Y5+TwXVs?!EWz?35)b8KrchY%h?J%x|sz1OH} zy7WD%Ks$6eC%g{CQMCpC=$ShEg4GGmufA8gz~5a<25o|WE_Dn*WLga?@JX^Rtzfba7B4p=;2A;)Ft4zi{q<^nHY$0UwKcx zLp1WftkgfOBHqSF3*nK=Q<(&+xa3HjS*2H+6B4NO z4KEIU+*XRaaXY6^*1CkRawDU|ThjUa=mTRYRlqvvV!Mhsa8WA_pPB->c&N#2Lv2dd zQ(1_ag&KE46X{jjq5p%g7{eeGpG(-V6T(?GUsu)2eJZOVYyj`H(Tg)qgtuCHz}hf3 z|CT0FXHJv4;FMPNyUACx)W3;4VifGyVxlE*mgo)y6b*Mi;(s<1%uOa@HE|T8;CQeA zuOV|7Ot;vWi-)s9O!lONpxOD@yOb}qu?6?&$hw) z>sW%19e5!KZIsADNL2zOcZ&(D&JpNNm?r#Ab&)YYu_7~CG+#T&T00^M(2xWWnlYgX z;b1;JYh!9v_y*kWj`2NHK}=;&hNAu|;H%!XG6$Huw+6YPF+?mb3wUr1$%foT(=gwA zncL7nW1cn*Bos}!*Tu->-oU3;B-gWY{g?kXhWkv2}C5HQ7cF! z3j}q(u(l{}?!JIWSJ>PUQxKg2!vKZ(TBJGsDexvKiDNKqIPrEGB(ETSL#6;h4qoDq z^p%ZMy`;=`)q>UNTaHQdgQ=+dTu(JMlnyc%H@s~(@^lYYQO#_6er%hp?J=30Lfm%t z;ks5-)L&I8i>{MDnV2biS8@Q@tofkOZSC+RAFHayJ>?HUzw+!1e|d$ByE)hFF%s^< z`P4W<3xQosI*Bdjv6KRZCp&SrRDzDlydMhR+ENFH^yhUob)2AUiUo{mnIO8pEw}*h zn5~x&w~+omdOe8<)9SeWGy&9RbN0BE(H4y6tr7Z#JN9fSx})h-`U3FxKFukqs4wrY z-3ccqePtK%?5QyEN$^Kx+^~)Vkx(K+ITXCjNA&R^U3!o48VgyZUFW_soB5>P*X;4v zK4GL>&(u;q%;7bp>KwM8NfY%Gt}!&Iv`|R390xewd5lhfLYW+<3YMf@|0tgGa;qGA zOmT%Sg#4Z00;70wPYARe>mg4v`Rfpuo!TyDHe?mwxQX^XZenlLsdB`xM2{NY zL_QYA;nrHBVn0%E2x_ZG^kv|vhF$bg>Hv9zDhd|LYg(YAuWJG_e0gXlb+VQB{!CJ^&zI!KP9he(rWWMd*@N}5?kL)lsoIiRZ9%AVl zrIOp$Sk5&Kg^2IdkKv8-w|_MC{6HMw3MPY*n8SeJ>6aXcW{LNnGJ!2qVL?k`Pj~jGLFpG(6&Iy-ADYGa!Vhf3Dl&_RR+tn4++iYjOybkA)YUW&|;%%Si z*BcE!wg)9zZ3koa**-}pCerBb4+NamHy?9!-sCK0J8zAb`r#|C#e38wL}OUI4ClK! 
zQ%ppsM~%xTNCFd{2?Uvrbm^+4I-OAcV7%wO=v=eh0fsfk!XcVW)BMe6i@{}^*){8( z&?fT$*&u?FK(s^s$qM7AJ1-PZH-CI#X0Nd7qCG^Gl5dRstVM<8ME17V)G;-YHE;XdjYB4fiEvYZ#`B*Hmz8@jch zBdb6SKG#tQFarD|%$G?0PQy$AHd}_S#Z4Fa;AIa<+|mNv@U0jP{%sULTMqT_$H4L> zm)E`@Hh(~evsq4nK82Gn{@%dd1)q3uk=0R5wX?=rZjjS&0M>tm$x~rRQsj?jZMII_ zsrWu)gpR4DP`pAd3oVTcO_GiV!2t3p%nRf(U1ANjkMR-^Qb93znJNJi?*VZg0?XTS z8DxEfJZi zK(o9}UM+yTuS>Mg+0iEUhVbXY(|hdHKt8KkPOIw~esr_Pt@~2q!;LL;Vm5eO>%%=t z($$tw+`fcS#IYm?Gx^+;YrbiPZNXQINS7p5NG zuI=!~X8;_8|!EFF)+4VTTF7Yn7C!^2cr% zFk~3=0D7fz{gzoTgzoI&JVGoc!i;L6AS=P|yMBB>&j*K$*6#tSN!eZlng+9;TN{I_ zSIOOL!82ABNd!#?YpN|!vt;qqk@A?r`KxL-tH8K^uuSX=^^JWi zpZ)e4M8dBhcn-c%?mP-8oBt?GpRl_L!cuQ=k~|Va6D!U*5r3Ikz zeWpK*I1fd6PY{D0p{fg84L`cEimp>igU;P*q6pLa-V#Q}Cnay(++8GeqUfPzehROo zm-Ktox;0p>=vdqdnL0Jvtx&%?Q>ka>Po4!6rvV(Zwp4d>ByR}|gDRhV(-BykF{-6& zT9cbCW!GX>z-?e@JG#kB#79z~@C~g_((9HaoJL!c#+y#QcwPWcXN7dXl%U}tQ)e#- zo(V(E>1ml=XNY%%aqYJVABBwEZTj=iZ9qd%sGQpg|R(RXXq?m3N|Y zqsyr>2H#bQ3tE5RBgp-i++!WVMvT6>UJ-)8nRrnr`k@ndFGR~bX0c+~5qV!*9{=R^H)YANKWWKaoH6_Objy;PYuZUq7c+mI#qaE0L9x7vu; zir#Um+H1`4oK;kGs$N~BQk)m2@1f@gNzxghH&Cl`VAKLX0k_JFF1d ztvJ(bog#q9a8N@IyHUOO#x3}KI*#Y-+MSO#vM8kPmkD{}a&LC^xa(lLYD2Tkr>Z)8 zf&)JN**LQfwiDTYxBhe!q>TM6RzQz*AW!656D-)1{Hk>YAi zr0;9ZX^P3^wDHbRWBPD1hY7o?&j)@7;YGGtS`w9%4iRA^~7|sWBv=pFy ztqPv_>}nIGg|KybylKdN4+hs;!DsK``L5>b?A+vq~CA5&>W_;N_l&jvRJ zg1ijFerNhEaOqNswhD#pTIPt48IY|@WdEq@#KK^~NajZP;OE};`)6iu##bH-%*quc4y7q z(IzYotw>=gg)`7rF5miewa>@UcqGprrrQwpl`jdd$T6h}Y8yv3sW^Htg)Kt{1yV3rnnkF+-^-)-KDpc>}E1$*PF#;qlat z<&VEnJbg{CQGsF42X8BYhf{|FM|CNUTVoYnjlIRj_~B7V8M?LQHRKJ&E7BUoRFtcO&};jqiIh99_Z@0gq{JEBLfn7F@-uQ`$ScM)HlW7 zKa@gn9B*A=VBgWH3jY(#AZ87>m;%PgFXQ<}r(r&5lw7JRq%7;B;6oYSHT z$e-@agys-@FIV8c?Q~+2{a45hgn>ZX|2pML;FP;EZ;Jxd60Wv4dU%nVq&Sf!H>*2_ zlXJQ528EJ0DKoZ6b=jc0E-f-&Ew8XXb%yN|R-u*`@v_-zNKM!a^}5V2MzVhQ8iM9U zP;>}O1F@-{r)4FY_|WN81-^Zj<~Q zg9}KmLY*7V(@!64E7Tc`0PM#-0oGvq2oyAYv`!6hj`a=l(E?mnU*qdj+v30 z2MqwcBquJB0o2bTPrQdHT2b)k_ODcYSsk+x@vwaktYr>T$5yR5~ zn6~n(WZ=Q$N8GFY4PL8pm}6Qp3-P37Jit!kRpR4x_8TxjND+ANIE$e@9)LwT0-Aok zP04~o0Evcj-u<4lzyHq|dykU(TTRSN6_~VZ?ibjv@XLvh7F#&3cJHl;(ENf1sejAeVSQ|F zDB)qbgNRhZc=^7dzc|nU_H{4!*2fnf%@oXp4iJe8bxxdqf0n*+=jK>9_03ylqYhQ( zW^=ozo3q^Uw#w0b?Ci=HzgT>VkQ<}fPkuWb&tG&AA?I?#CiIRrtwM9J#DS>EwX$TZ zd$CQy)Gl2gZ`NLM#pf*E{3=^rG$OWiFw*apUR8DVj@Eqw7n-%TIyL2T(;#i-@{OGc zG~HC)8;Z=Z5=hfoN~4&3>1tYF_^J%ij6 zCVHlx5GD23*H&5FE4Nx;UvIf+`3K5WM}2=xTy9BQDgSv;(7|bW5z(`5BT5{wFS-w3 z#kPH>?46rbyT!ABZf2{XW}n-;-zElj}ky=EE~W6dUgg z#;w*28!hUq4JFtIxdji1MV!8|iT0Xu+|*B@Q9VvFgn9!bBRzKS`Px8=kr(fNS*NEXlbo3Onk_M)&415-pKdGsGe<7iryv&zmw_k9R; zbljP-@Y(_vRNw*4^b@^Pq`Gze`xZoTqh^o5;~pfNS7@QfrntkHmeYveT(Io2JpO0e z7y+wy4|ie87^{<0YYIb37{!w)<`0{`C9f5aITZ$Ud?zb2by>FQUv5ebpI6*fLT$Y= zqW2O}chd8~L5w0E+#WL<`GrmaT=;~27-e36K3$@Ub;LLQl9L%;J`bHHyl~Iz^rkJ& zj@Ep0D`*mD9z8JhzJ!KTqRA$nz{~u-wZKu8_S)A}U-B4a$zO7U@<@IUb+K%uop1%W zACU39%R-hwQGIRauI7-9aAWB7dhKv6`+W8Ok(;aMZ!S7{aiSO$L)@rDd9I7Z58wav zV8GQfFn8);?uJIT7O$GXLt=IycU23_y7?`Bp0;MamR$TO!6Dog$9e@Jn$=>f*bzhT z_Oo`aqv52x`kOc+S`?h7y!CWB5LJ`3iTexsR^G#Ec}w?p^e@=R#l6Cj0Mn&{;fWsF#PGIALF;&cRTGm(qM~Re^gC%S#;5miP?zmUrVXq zfOAVD)|=g24B|@UHh>S@@}W!v8$y8rTa~W2dt4$4-yl1^y!;FJL{b;evqk8xD(OKE%c(T zh3Mc2MqmgznJ$zlzR;2y6>caWW%;2W=x~&N{ZLr6B0ZKuq;zho_f`&VrkhH)n5k6_ zyYIs;iL|8`TD>Up!|U;DP<$m}v&ZkW+!ZBeFLzkm&Ze!w6xcpW(=Xk%Zet^6-wZK! 
z7s$=wwVDuhckGA`pcr&Ic;uqxkxu5K3pB@G2LEdhvrdOH#m@v-@e%LQpoeu7=_~vQ z1`oRyB)(?8CG)IrI!T%;wpAyeU0;)@w`#@YVX*^A6tUn- z?1nXMdeddMWvI;aQe80K*!>ZSFH>v?0g+5C%3A!cj=cNt=7CvJFZm)OhO%bx}hNRY)f2E)}dtvj!CV?g* z37a8n%fs5Zl-^{Hd?Oy1K z;DcWF##|weft-t}RPnR-6xokM$LSc(W66xP`?!evUD}jRIeudbbpQD5&`_0Yxh;@o z<%2$E`NxO*`(C%|>?IW+g)ZFJcb$8}l1`v$opFH%hMjd-={hyJvuDU`660fa_ftRMNkmBFFgQ093TrRXJ9?}Wrz+mN#C!!Y(l7CJ0O*p1 zCl4mCu)2Dh>8>vcW;NYNPM19A_+wJt``J$zgxHC*P7=oQt;9-yaE);r5wJLLU;l0W z2h}yhTeY3ZRk!)p8f=ZD*8NrSxIIs>yuX$mu9uTb?>=gSQY+p1lUu)lkBV)r{;fS; z{pqUSV#qz$Tq@0Xcvnu#eC741e2q_)Gw*SafS!UluE!N0Ye`(&Ie}7I?>oK7J0=|| z;x{qy@Y4PgDZGM6cqMV=-cjEv%|PG?^Y>8N%MJRi-%5pF^%)9)EW zMtyjuanFYfSm}w3njga0z;TReYedx%teH&$OHwWcZd8vX?oG~lV!RZw?2I&}m^MT2 zgswy8q_pfsvPs&I744w>8GkV(IYc2Eqf_`e%sGb*r&o^CYs>3mo1Hu|_M>uw zxB>yo_d@2zYu})C7~s(>clhK&SW&o7?VZbA{6yuA^kbG$1Rd_Iul3pilhWa9=ex%+ zsm^JeC+z{x?EF%s6nVVMIjVUNH`n--?FV$SyL%!&GJ*~4< z%~SNol6En00`?&jn1tIq%$#}^rG4=z(S;&{$F%|Q`n-+t=F!0w2{kJ>LTffS9~AOT za*B2ja4WSjXne7Ivhz^Z@5=vNKIUlY){qx@tk);*n-OZ*?3fdjISy;T_Hp)uX>~r< zdE>X3#Mo^Y4lnx(RlnF)9TvJ)2t1&GyYfvm`oTm}wQOQdQc_v%hXLYmG zq?f)SwKFC4vYVfM0b0PsKBjBh27zp}17)@giZ>RDlOMWfoy9y?dboEh4V%66m&};n z*ye};|DMpS#j}f9^XCLJ6xocKcdu%ZT=F%ZBtNK+Hb-R*N zh(~B)NS~pe=uAT5bHvrk?KbvX2t#`1`?;zGrb?;xokL1xqlgK=srd-CyS&{w0`BHUpHDF)#KcD7i($DLgaq>*#s?-XUjd?l|XV14f9y zqoEhlSl~jdH$lGhNHrHF{+OHX)=P|9_199mX<|Mf4eNc%e2-D&G&MJ-wU6L68G9~8 zjbM_h!Qdqg{*A#R#nIZ=mw&XchHaUsx3wFtY5e9RY>}BcRjbnWeg3lJuEV(O4}MKr z+`ps7+b?ndlNx={(fvH#ip8xgwZlN8cM@N5j!~k-&GH&IvC1;yBt?-QZF;jN-T$no&}<&;^1|<#ow<^2y$^MkI-!6e zAP^}ilsI`5{iVq9Rq3 zq;FfRx~-*1spgp1Iqftoc0fFgaPQVpVzV1A=5hTM1$_oDMHv}yfZeSy&5#7TQ~Wv% zC!f2!R9CQ&R`^-n%mANO0q0xO7WnAZKbjfFz*6Fmyb`YZdt}hFxBVhk1Wzl%;kYv+?a?2wgRI z6cNCMs}6x(yF=t8{qyp<(Cbkmn{^mIdw1gZ0^h?eI(4+^hMunx9f{XCdZVIvww#T- zY5AUhp9c9Aty(Fkei)i77j%&U*TB4c1(7A$#09q`FLjr*(F>e*|G!wCWR}}Ls$G>EP+Em5i+Pk#NT6p-~ULg$! 
zQVEtU_Jrd>K9q0Vu3L_2xkA8I{>trZ3|byq*hW_mn;LE3kU6kw-9AeZ_~!o-M*r?9 z4HeE19D&I38eMXG#8u-4o~3N5)zOk}KUrkQs|Cx$UvsrGm>bMaId}|2@u_uz{qOXTOgrz9*UG*X-py1#cRLjv73Ltygv)Ol*vYvn-zWIK}Jt;%i z>3f$But>-oZX%-4kZ&44j6RXRsE$$NsZO5J&=RF0V~kH!UrUZI88QiO-}z_Rm2 zKF=~RIq3NHba`K*)g%kw@+L4?0`BCN4}XOv%><)LnRIZ&9|V7|0%+z3ttD^&rkb4eqGDh+2Bt)e$(l2 z8VU>h{r>UA-eh0@CAM9Ldv&49iaKnlx3|NYv0seTRY^?t ziMi~_*q4K*ojIF)V&c{4-LdTU&Px-bco#0K_z1Z_Xe#p#kX?Jj1_B5>|a%cgH6=TIyCytj%ik`4cz!pZ_V`RmNQ*C$xl^S zC?smbBjX&3jkE{!%{+qoH=b zbgt1p>F7QAo7Yj++FaF@E-w``6`x~PLj zCJ*oI4jI(j1UJ2uCbu0=on-4;5lu~>KyBYMg4?xmqH7q;62C8WG zA8Ld)9Md+;RPFe+_@Z?+o*|S*ufUA*DLO<)Je5mLrorrFc$7TQu2#iH?@;56SYEt?fU5*n;=l6kzm6RrMZT zKJ38-;vHM4m2Po7|EiK#nGLt>b4izU${$_nPn&bbQ93uo%hQC*@7{2f5V8D;&5kRJ zX~j?sm)=d#`^hj|3DE+DveIgUosowdoO@Y-n&L|{<^cgT_HysxbBS#SK)AzFMbm}q(H5hjBp z;v5@02XM+DFTZzR&)G!11j&3eBhv|Bb5UW@KjMi0{l@*v&bjiM6wN(SHU8Nrq zkiOFK2}{XIuN-aa7YECa+jAl#o$VpiZnbT2wtVs-q`EMm;Uo^)Q5Vr)ARJaRxw4L&04`z8~-Hj7= zZnGRKRUyBs+7P~Qk-FhRu@jrN2tq|M`g&_;8Ax2de*9duyq^9g&+Yhy_a@(Eg{PBM zm8j;h?0j`>cCfEq-^xX87{8Xg>~)V#JO5F|`qhr5^!OKSS%fz+WEl+4{dH;b zzqpGULj`zKTWY+2SNuCHpboGO)>n@GFn|0-hHEfo?Lk7iLc;ST7G;O3Xr70i`nl>3 zYIhlQm%grFE^=OR*gf&fi$C}y3;w$O-K04VeNmTJb8H{%Z+u4UB+Z$x`w>i$8-q^C zQhU|?cz5u@PgPtyv>&s$Nh9%PLg7nqK0<_C1Ud-hdxi<7{Sz38XET_}H^k^o3@4I0-s1kh2 z-D!@oPrt>#11%5{V-6J{p5VF)sP;3H2!fZ%>)8&qmWLI`AsEg&X^N_SRUJ-_eq}u< zGWdKiGlXgci@Vw^hJdTI=HXmN3^y9>jT-Eo;VFFucc<{sWy^NB{Q{A<(O1XU1yWlKP{iIzkwKlTzxfy}0{YAI` z!CY&&P$HQM44I{t9ujl0&_j6ho(Io0 zqZ`k@siXC}Y`Uenk4!kuGisG>HBqCb=w9JMHA2pTYbZY6iLaz9iHfP|T`j!}EFyph^Lql=4^T%mT?jmYj!ztLIou5ZaDr?S98q<$`kmZ7R$Wn%-eE(!6uZ}Bx87wIM zG8)1~c3A^Nw|~K@x`_0mhI{<^&1U3h@Hsx+TOH>0X2JFM6fSPT2s{DDf(E#GBuS}C zcu;S_1M;!$ExU^h(c%VNhDHJ8)@KoNg%+dV>eHYvUIx^TutMAm4<#v0$Pfm?6H^)D zH!N4NF4DUn<$Zx0klg8D?adu7mE>E&+n~-Ow!y@rG1}zAI4iID0K_-APoI9x_&}tE zy1v9HLy54G6vXx*nZKk3N={!m*xYsp?$ap??$yG1bv11e2?FTyHGYi}%ZH=^kAcK{ zdhlau9&>8|g#T-49gaq5PZIbQ<6jQrF=~0p(Rtb6^Er7rgKXi?xl4TFsl_~v=!or#d%~*e+{qwTCuta1}J&+6jrOEl=z*o2U_Vkm~U0J1p zRu+n#@1r8!6{24!l%5@STuT3Fd#QP!|BYt7w)$0#TKq4)pYRp!I_4w`MFVDm8RKnNSc!)k#s1)ubsrm>Y1kgPCB{@GVpa&sAB zc%!aicr)F3N4degc`_{+qJz`SaLyx8h4+kUB zu49kYv--y7$pr`V3%iW8TbLtWT90a8o%N3nbrF5Pd7DjIs=yG>s0j8VTC=S`#T4HEn3&@_54o-GE^p;`+TZq^6Aj# z=dTCBr!PJu&VemYRf+nmsvLcsT>jYgY?y(Z8UEEIGM3cVW<(6Ku|bW8TCb16mVPsI zj-eL(Ge%~a*aQ?AY#LdB=@emp=E|%4$-EGuRo7y2SQ%Z@3-(q)fdQ83&>d zv&o`okxvJ6;>D#R4WVz#KRPG%3ar||-CZu#3s~B-($lL3{O}R1o z8L1g9)W4#^_U+dd`gm>49NnwsmY2HA@BUVI1it(2gZcNRh$IHid$e?B>W_}Ibb+E4 zKgV%!U2Hj7x%p!?zhzjvW54nCoJLU*OT>s-Y})qxu_LcgkU?t4RVMYA!03;n&LPjc zZ3WF!*wNKmAQWlE52{|4xx7a_sY#aajJ`2mQLmAUB6vla2xnUsz=q)d9ACe|G?D?= z)0T99zDc``?7+8AI^?3USkCEqS!2NOV|R1>M1BvnKof}I5G^RUV0rU(Z<(!G3HlZ^ z^AfhaMPZS`==M4XDTYSrQ&U`^i!$bqcx*A>2(a=ZyTMNd=FBN#JgKti4d?S`hhIA0&t zHGEAB6~6GlR^9zCpWqY{k!Zk#A=H~>l9Y$5qd(HEgd*47sd7CWxHy}Z$<&@KhBx)I zSGeI0)f#!KR`0C(R#NjvWqZ~ScG;S{3dT_B<}f$aN99CJ*nEUIF^*dsOma;FoY;ay)|XN*A1<~`i@LhI?e zC?$iV>9-5uWY^d_Krkc7{i zM(?f+eoT@Le6gJu-WB!tdsrNQ^rr3kvuT|c_-g}n_feDYKKiSfWDWCE1v4b zX8#;dBkjj+vP`E%Yu~@|g8rj0&1XSFAhQ>KU$TnC_?li=e@LspQ6!;1w@2*nmZHVy z&jSfUku%i(RuTQL^-$`R`fqHf^O?AV zcnPtX2UOR;vqMat#{KHiCvlO)$Rvg0ppIje9tFt-i+|^CpXwof{4M+agHmh52efEZ zOJifAvOgxAVkO{38G4lb#79c35;|RI z#>D(EkyT575M2Z74;)>tU>_luvB(8nG6Q-m-sK=eNX0tdjYP2!dy%2LI!X%GfXxr> zQP<2uUI2=t0kELPLRVa}a9^>7&W-2aO;9X20>jyHd63d=4S_4W+w6H{8kT7~F`#zp4n1SE4F+|@cM z2DE|R_D}ISLY9sD{}L_z6MT5>iuXu57xnrIPXW&KhLE)M39iS*;H!Fbm`0 zt6P1$xo=uGpb}P~jL(x~KCBGZkWg{sn~m~c!QXM@UZsR28^V6_>{q;Xgws~KU*2^H zrjGKjgh`Tl0QXq`3$6Mx1BZE+M1qF!!@2ASMI+jkFM@w%YZrX0HCr38pD(Z+e}R** 
zyK>5xP<=8oHC13(e`{yB)F#}s8)PBl-SeF~G%t(IhnL-oEyw3u%tnjXLXL;4Tzwf6 zpv@FM>F$avr#UQmHfYJRSm3fx$I@%o9bb7js(YgHFjn!U=26~At?z#%j$jEzPmgv& z^z3-)*9@65LzWv27Ap5AZ0&Gy<$M=;?As%`H*{YtrFMbsJzAYw$~)fe;7K|VI9*`D zdQ7x(PxdY>ejFrqboHVr(W8XHuJ+SVj$8VihQPut*O$4bWRI?ycJMT--15DR&z3ak zfASn1Ott_cb!`hvYkH-r>f^TxtEIo``n zAOf$P?SsO{(;W#C61i^q?}6aAMkA{Am*G7Qp{R>;b6EMN^x5{4axs;%&~#4yOw1`8*{iCuQcJ>?JJlce1KqayRsX+X5a4s*0T%6A-BxykG(AuhS zSBL-(!N1EB585}ITcLCo0`#SoPx_e`V0f-g7sP+ne~0Od4n3i8ZZ6*_t6VQ*QI4n4 znoRfwXH2Pm4dtZ%54jveICM_fOx>5iW_FoGKAb1x)*Jg|TZ?^B5UoB>_9pSLH!{i5 zs~iza;E|F56xwimwD8GnX243b%~9nm!;lY(kZctmZH(5;xQM?jL2<-wSU{561VnKe zb_G_8JU#{b=sAzu6kUTeEG$t~;P$xCv==Ev!3Z_?GT5$~ij8g`-`U@YAJ-bEj!{EH`NbZD~eek5;QBnS4 zKFlZQ8~J=w04He6F$J<%LiRAy2A2e!+eSp{a^Ye#<(l0%cd|ES^_hslBtXw~cVPaA zGE{325}GcQgrYs~uK1SWlF}FfsJHG79EJ^S88LWy_#<$iO~;!2rv_K>JKB9H$>V|D zqLmP1Ri#}WRn9*5EL?UifSg?z4C8^r;xZgkGyVb-PW7KWy_=ELA`aO)=TK*>0cUI{ zpT>Gf?{$swR?y~sTl3nsYu%!=1|Nhqk3mqzmF31q1(!*cTB4*!je|`Nh{Hr+h@35- zl5-obBd!2&c~%qwo(CbaKR@Y|yjUI)*&nl29;JEDtRfM+#e2&1xt_oL8xIq=TReXfKKs;qNro|RL|Pn6ueg;4 z@@X)rNJf5+`rT+oM!n9jJ~{X}UO0Sr--UYMH8*AwPVB|YzCaz8xk^reHeW22p$%m% zf3T#?pgIBh3-9`4->@zuQ^jGtjq5;WF_%ZBCV1j9b^A zSN*Ew4kDl^cY$8FOx4s_KoY-kT3O-V+^hzmTWrDS%(-)6LAuHuHk2n0ciwBEZXH&D zfADt=I9Q~-#e`UJ=dH(<74_mk?Y+AK3tNDpI#FzWil?AM-j2X7TJLni{NODah zTvqN0y+sg+KRkY78Tl4;3AbC&2kyBM0JxlI?7J28YA2FarEY2Pa{^{4^Jc&{2;Sd| z(y(jgCN71$y6NPi__T7-Ndd8-hP>oKAbO+avMVq`t68@#>SH{uh2Wva!7)PBOC;J< zb6CUMOk*5ktRkjiMdh}>NwW>`r4ZsAA&p>6Km$ZZH$H^6JBNJU#!Uk2NVP{1#s z>R_3D_UyUcX%bUFrsEF~^wUGK7ax9aZSG3B9X_^`r1%<#oFrV=Y1&(5*OQNeThBa) zzzDM}4aC6TzE8h6o^Y8Oz?tKj-^@R2*Z@pcER3A*cPb0x7l%aCjQ-j{Owm<&DyN>B z8&VBut1Y!M(R0ElV4HRlkka2GA(o5Mf3t9r+|Y(jnmsmIukwYdqV@Txn93Ks&v{G* z&xJ}02Pu8E$1+KbjU}Go_%6KZc;=T|t=5nbg>vqNc^7Is3x{}+!nq=j48*5i@J?Rj zeDIy3Q5Q#I`X^-jL~pP(9RT_3NCojGYnHUbn+3ObH|R$2M4Eo;uo&a!Aw^Trxs0_H zmS8z{A#~l%_Mu_LlAKZ5ZewF1IQXZ%j#QQ`Sb7`xAtb~Hi%S&&>-^@ zOuE!f+P#y1m^bRJ`T}7)I%4$S{3n_7 zJH~BV`B5!#V6hbQ#Qj{{kj7*64n4aF9wG-Fq#UtO~nwu1{d4A?GPS;;8P*!VU z#FJQy^%c9{eT6CQlc?C^G`z57$UjoEzz60&3~_;61zPvW@a(iHmUYOB&4zC7}g zb@l2A;{!|bs)HB!4j-PS{w9lVgpqa+0dNP6)r1b>OqwocNG)(4zFGkL!7yp;0m4`N=#(o}vLo0Z$t6`xyl zt3mB`YqcfTTEjW&;d<^)fgK@`MH$eZ!pqXYnOTp1ZwfyFRS= z+V+m~S;*}S20tA0n~#~wOudY6H-azt2woE(q0(l28}EC(im9dPLJAT9lihPX|RpMb>AZb3`5Rv~KR)9Hhw!^>KkkauY|F)}yb z@qU(%2*24Ye-i_90~y7>scI7Q^0e=`Jkp<69H=X$Ni^Z=*g5XAKyx%+k0Z^F-qdi{ z&u{M(utIb^g!+6kusz_4vks$g(MM?ZD0hMaq|H)0s{JKi(8~H*Ok)%hV8r6tZ7c2Sm zqqvcwo5G1TlBY0#=xf_}fL`k=P&JKh$(1In%iMvT27+eC?%*DUJVJ(m zM#_^An3{PduR z#)X>UrMEtFnSIGB%l4WYJa}8e<3JXz z+AX{4k__r@ctz~i$A)v;!&#K4#v(Vc=5Z~m@wcPjlc$*Ll8acYd zahRL4|42Gt<+BF*o-V^~JZK?)6hu}8{m)59g)-NNa+ z+XG$iT86($06vf``{x4I{(m@XUSVB8rf_U$tWrztNbOfo{NrqmQa0dL{&RkdGKI4L+tDg6Owt{ zLr}=@9U{Ucvpp-I0dg9-W6Gjki!U7-`*OQ=UbIh+!w&7=w7y7!AsRPiN>IVWZLF zot>*1fQMSn@<*+2N=?>0Te-#p%+8gMh&UX{zu8EeTHF*tC2NN`U?@5E9XB zjbKO-d0pl4dQe*@DkB5PkYj_65`@K$^TppHT}7+$0Rf=h(mT6BiVltyzMuVCve=jW zCA|@Pov)0kHiys}+t+@t_k0~>Eb5_jk8xNw_pcb zK=^gwVH5C0ZijrxvRa8R_@9icVpM>5oi3|)YqEXP6B z+Nj3F4r#~SC@O7(-B{fXa74D993KXok@2j7o}yak_J%$bxIUcHu6}1Qq|b~tkk&)T zwj>JiW@(~zE-A*aFc94Stm|wyH$n4o|;BZ}g&2!&;3;+ysHW=2S7{gw zSapiCuMQ?kVYc37Q!?$f0X)JogU;zp?s*2ON4>{zt`-oNjc#GFYYBRn8Z7pFuyl68 zBuLx#sdhgr0!xhB$S~+_eMjotXC~Eo1F-GtKT1>6^QQ zRxSPu<~c&1U%1N8!SwZhO})}X-V>IZ3j%6viM$y6O$5fQJ+FH9NVC~hRS2I3L4S&s_Tx2pf2U#!E`q^1rTW7m@(VH z;m)Y5TeJm%hKdGk7Ps(87I#&SjOie?3f&ZV8%)y8%2*9ua@a!?tiA^8a*S zmY~Z+3WYibmoEJqwzT0gx>(bEa6MFHpPLbsR2VUZ!6AQYY3g}3}1zeg@(ogHA{Dkf7Oa9pLR zoYb~PpM9>J|8@#c(~)l+9DIToAKHcKFtd&+s59yNcyXX*VH+(HE8?J&9a1VCMHJLL z3!VIp2-JmaL1{3yRX@3bq!xc!s&fuQcofot(bw*>QO-FknUE)siV#PWTVD_ 
zG(u^d4mNn7{xIUg+z49`w@Fl zN>k#0t#76mz6>v1DXv#ejdShAzzM)h9RaXimi&&^!5v z>BKPMh~#&M$n>#$_om?McTjW4C%=3h=P>eQ-TnyXgW$2veEqCH$*Bepy@s##(Pj0w zK};1z=0mx)>IGsShEU-kuo6Tgd+&(>nSd114w|?8h%T%TG3NC}2|ygoi*D-jIc*v+ z?nw!O=U-4}nNdkDSM{@=SpT{Ca{8<252kS6{TONOYyVBZ8;;%4b^B3PdG2n5AAA*I}>Wv8KW^c%kD<@9F7 zQ|n3#k}Ql*%bLg-(I~GI#fTsnidWCo8qe0gRzlbouy|&4b6XWu!uvlH+)tFg7yVE^hqXOeIG%x5?pvTZS^Vw?~;gn%vksMpMt~D9g`^eW3y5lxRZk~iI>}n zi0oZ0&oz|zH;46eh(p0pZ?w2D69zYTDM7H7HOGU(49`Yar{>q#Esz;?=>(2TG>rYi3p>LzmmV~Wj#GJ! zR`5(~e@*eGfyc}8w3h`Jd#V?Uj$n!b(^>2B-A42Eyk({f*@<+eV_-zzeUD@Y z*N-nZBku87E@kjZau>SHD95NJYn$2u-7)sN+x(;!znlRyDD*5NBqq7QkG0U4B#b4E z%YNOc0k=c~p@Bt3(KKz?q!;8Hb=L4BPug>_OfUa8J2h&Ab!Tr=%Y#SwDgp5YN&*u8 z4;(UQpn6KhP2PwFVMU^=LhWG4z&;MJ}KbN+b6 zvk0Ho9+^hK7o!wa?ebWpgR(ww08HZM(i>AEtj7KpUYBD;94+tdcREs=OuI6@o3{Ea zcIqh|^MR{@^JmV`q(I?#sm*MZ;mo~mQ@`U!N)3^D2cBC9BDZeA)_Bjk=fobTII?Lc zb&EpD5h*X)$0I9pGsmk6O%{HY93MK+(s=&bIXXOz2;gG*<#6d^&O_!22goVj8?mSE zFkMn*x&6&B?&;f{zRQ6%&w3OjQ)g4Ju9@C@a?wlSbb(weZn0K>)|zkD>@jhSZB{8ZhregKh0u=tX>b&Z&C`O0FBNZ-)BY%@tgptAK+}6}Y@> zgHbX?9NbCugHdiz&IW5Iv}}{6*gXWIw?FM=E7ew^R3IT-`8}RzxSZ$5f}Z6&x0#u zq_SzJ>(erZ%}(Gez~3R69ETe%%Wl5>6tyKl3$@hb*2BO-=AXv^{o>o=ySo%k^-St6 zx^}3&B!(jPqn$iUYGsv7Q6{#Df{QlaFIL~vg~wMAfoMf6-ycEYnYrt=3a}Ir4vox_ z7X=2`BvOrkP;DuUp~^NBHRC@*u6}>zHq3yy^+W?g?pn_uA#2_FjegotT(2fJO-SMX z4_d+`JS@(!TaiXYcr@ZzNNW$CEAL)mJ!95Ne}y2p?5U2W?MsZkcb)F+$#|d~%Ss5e zxC%v<;}4z?3&O|?dj*%B0O9vUYH0_%E9sy|cq*f(f(^-D6lg^v4~pVD=UqI3`frh> z-%V@rDPJW!kfhUr;U-E#+DpSyVLJ63ijMavtQvUCM*_*2pA>G@!`P?nz8u`-E*qXh zGW+NMkLX#^!&JbUhsO-DCktw!j26;u}^H8OIC8 zvDP4}Yl=VF{WK85IMd3AF5h>DxSJajwN$oV7p46YKm+nV;}YUyO>!@=J*IU;zg+TM z_ALh>n#Pvnqh8g1LinLC`Z~Wn`^PSkQcg|>_d1CcWKTnXQM+&EOsRNR5}5-wPj zKhWx8WWvZ8%WfzBf2;cSKw{{Yo=>W)sDFP(&(0FYJsm(Jfx(>&cjD8YM5!MW2zeIA zN89L+x&Q#SjDMXT%w!n=zSgtr7NslN0ufVJ}W1}&M4jJhxinfR*DWkX3=6-zx@UFz(T{Qa!! zOB%wgvNAnZ@vWV%E|Hb(Ofjz!@q-D&{MX746NZ&ao!9?*6w&|Oz*BjQJTMnS8@Av{ z&=K4=ZmT&77Zouau%?nQ%y3GxrcQw?Pyk`V6>x|3L65q&ZE!#IFrEOCIdR0IS@m= zfiMDWwe!(@rg-ge&y=^p|KFj)SYMG3XAUmTTmeon_b4DE&;__tefw)Ib3US$FutHC z^EEnrbANU&^srXqI(jh7{xpsoRpFEVrph%}noQ3N%!;hQ+ z#0ut>p9~k6?Q&DEmO5_4M{N(e0ne~2@jRD^g(f-bhj{w_jSm%gxYkWURLL9X_a9vb zQ11+(#n-gd~+iYwnb|gqpvB1r?Ck+tVW5GG2S9d{r{}+W6>qlTR_}yL-H~q(N)N|F&``Kt+dcYa(TQ!J1|Wn z*YbhT=Yx{OAG)F3Sx6oDK+%9DVlYg?a+l^kudz?X@!sgW%2R#s#5CY~9t<3wR>nWG zzTuHHE<&yvFY5fLJ6`OIHOLfFLFvN>S%IN#h>#j@_>N3;Gcj(*2T#I+uCBV$B)tKV zjx(;ZqBaMpL0GK|JuX~0%R(*K?nMSv>V>a2>Eg=8tp#v}F};lDi-{!ev)N5<_}DDwN&g5W%(B1K>RyvOL_bau(js43_PVu`7Nmi>Otk0G4KpCmnKl;E4)9> zuM-$IT#_fYgW+u>XCHj9j{A2<{Qp^TCn4ZC){UE1x1K4iq9RrPVt`FE_2X(k}VYw{o$tXOl}Cs_nb@1bgSJS!Y;WYU^&*)oZz~r9T$)WWb?sG#LJfu zjLZjo+w$i@i5{9DXmHAl`N}!#GJ7CJBBYmXtY) zCX~4;gmxv&66bK*PZg$CsmTBbp0?3gSzk2cde^EJT%h-_t8pkHl}P(oJS5hwzg06+Vei>1K(eN@`ekXhG#um8ohR-hH+E~VpX~HDl#0Ucrd4f@Jfr(Sa6Np$Z$6Z`e>3Q8GubS~EOe?r(LL~k!s*`EBe=I|X5dVV0IeoD z2YDX=H7U@fuFo^f2)P}V-c2@k@QDOr$U7K$8Dp9XvQoN^j1 zc$!n|tU&1$Qsd2y*KD|#8AuuBLLwm6GD?Z-KT;(B+;IQ0+rWw`6ejt~5aV(ne5Pe3 zTy5lIUXKV-z|Hw%=-NS6+6w3Sk1BA+wt~v!#{t~eg_MmcEzV%CH33N~UDA0nQJWdB z(O|6Oz~{^$8=7&jmij`Uat5Uq=vZSo2W6Nw{5zMD?_F^^Ryx@1@~&Ph(mD>)obp=^ z*YK7~inwz!yl9UXv6=<+gnFPy|#a&GStumNJ$MKU4nu% zLrH^lNSB1-7^D&dNOuTINrNCQAV>@#AgCZ9ARQtl-F@!S=h=I|@7cfe`u=nNVU26K zVCFk_T-T?r-{alBx&zNg6?O>3`gSgV zjF)h8_5yPy9_bWY;Rawvo^_TIGzwmpQ&Z`DP4f7vS(QT+(7bh+UzWn*RY?k}mgNA) z7Ax@gil(BUsZil4U!=rHg>J#Wua?h(`I8dj558Z~R0+f*`3}NKAKiQS^iK~Q75Fl6 zHNO8!qlL?v>j1Jc0lpDGZTI41`HVc{n_?co*Kx3XqAuCOwIAX%l&YkVwT?4nO%P$FUnf2NQ@u0MUp~y2^)vLLIh! 
zcTcF{yZ3B>1){EA;hm1R)ttI&+d+&OVkN*-X!Tz5s~#xBare9bj85?rF6fQSTzzVDaic7cq z>$Xh{02SL*N>y3`v=!7PWA(Igp0p-aCcpoOHR8QaYfM z9)!?%V9D|pCbK96Td{*^TV60@C;llB4TiJ{WYm5xoc9C}Qt`Y-0zSmU3GC|B9^$!7 zU@Bbve?CMDNxsuS{#sA_n>i+ppb)EhP?Is0C_IpqlYr;q3ZwL8RF(q4ySX_N|ECKl zn+$VP`7HgusdT|jl?RGH)0bABJIi0yz4v}PiQ12pTlY8lfN%;RPvitW(ZhR^=5LWR z0n`O8@3Es9Igq4jM4QuSZ2ovOg${u1O=|8uw;=lz10MU%>X;Fbd(Jcsf~~ZCQN=+S ztefdXNr84Dat#=JJ!3<@_)(A;Fc)%&QYe|$~A?gDOq zc3G_;fU;osQ>b@5E-zBzL+bT1#83YVdUem=?eweT#ZV_8;5OOdbb)w8_q79<={xXy z))VTIn2cQoX%>%-EPFus%ueY-rk@2kCOI(4(~}@bpo&4#T?1oORh#Cbxay#xH_-_G zcWuZqq3)br&ofQt7(TQbxY+E#vR`v)mo{azZB+o^<-_k?EGlh57i4p-)MDYN$(pkR zpm3}%{tRyMoi$n^n=bHA0Zf2A=HseGto-S%2$TNxag~36T%@-y__$n{A+ctS-d|t# zaod!X(u*PIyQw~;2;~5YL+vLuASSavIoNd2)5?(I z22pOqW8~b%ULcvx4y31=?##)&HAQi8d901EnX7o%_}JkwBKrAXeOB{T>IGHy9`RkDrRqucxB^`>Di+ zuLUH5Eu#vk|J_w^*B`8JLHQIG2sx~yHNFXnu zc*Dp)w}|6pCxH%kk+m>n`w56eg3J(SJATRD0!Z_c+exSjnsD?Un4t_rj5lXHV`4OY zw2=vzN_Db>tZtwsuj7)W`_I^6i^rcDc`cisY=Kf(4a}9CI4-A%x5B+s>F|k1r7Fl1 zRCv3Gpb*R(u2=^_)v+KcLK$q|r<#m_|BnT?{|n7@R>GoejhE*$`g4pbGj>!0B6mh^ z=4)rJnMeMgg`j^ubsbD-TdCIX0xa_BqJXi`{?a`Y5a$XE7ty&YajiA9?eM3BRPn%k z;mx2-NEfOCgo)QzU)Ydl-vb-4xhc1c91q=vV;&&wZdl ze#GdZ;EcP3g7cvY7lGzsNC7mzl+T6j!#iapN<62h|0K2O7g(i81gN#T`Y%P(?M^;d zUu4*U89l_v&x?S8Zx1DJgIAc#FKz%0{?d*}4X&A_?fdXNbJ0nkee)O_jE@A9>qlfO zMLKc^3@p|M11cGQcC`U$&dGJ-9JELQjMpV&TNgp|GwjOd!2AqsG8Ftf&=ty*5f%A^ z3=TEZaj_MF#>om>W&Sht4R`=mVygJV{Fl@!NTjd7JKEv#j2y zUi=?&(*LP``u%!(lM{gsFD6ST)Jhw=uCSy^QPSde8N^h@3z$n3Sv1xIz0xgUyT5%C zr#em6uSx6i!;NXV)ifI(G0+*MI*n=jQN&ct+^cvpPVJoPc1Vzt(1_V$^@{Iov z!*u>yX^5VcIz!E*Idh9kC;;GWT$K zkURZ|1H?8^fI+?fD>?8{#EaCcw?8}XCxUql!LOPUTF=_qkfIph3^g{|3)dzCF$Zu} z*MesNr##jexajzp%L6np;*S<%*D{IegjHtS!xZ$=?fLLLDF3U`zFM7O2l}6f1CF$;toN0FeeeIr7jc>xyo2mF-hl`C&!YhV8#(3_ zY)}KIM=pRE*6x=6^N;@@UL*9uWVG<&>uJCCSRh!d3IXc>7vC2gwFvBgR7SlLzhD?V z5C{8zabNy=%wLi)r+bFm@Aev)N{I>nFWOJghi3EsdcG^apRd)g=lkd14tPL{Mwm0y zd+7HW+QMZlbh+-&Q995+gqHqVC?&raioXtKp{V}xG6PQmW)74tv&$yG3x^9iR_|k{ z|G0|AwBqo3Ac%kewd(hZ#{!d^SkUZ$yk9U|<`Ed{9$tL=_hE|-@+$s+SO~CP=y<)4 zk$~k}0lzb^TW3edO~4$ z^!E;Hu~ovnQ;*;8^ajZvy;Bh=?d>FRN2$#Q33gpoD{TE|&x*VOej$zQuVtZmUH4Dc zRh;(p&}u{Fd5ew@%LfY$#R|3W9t357-LCK#zgqRU-wS0^@lO_tq72qtX}~qare!G2 zAx%5)-|rCR2x4cbjFf-x{tn7N`oxDg?(5QgeiO&H7E37(eb9gZzEF!_`>5deJ~IBJ zeH5o{)f4T*dVb1Tq|1?1zA*WV}KyFWSkP+g!ZQ~SP;^^Y~>{p;{4`@N=Sl>Tr{ z0nV_Q>>pp&mjuj3;QOD8K=4m40{9kqm|zXH{9_Ge{8~fBzt@oGAFd&cF?t=_0_(p= zCSvC@<~X$ceH^w${4d8L3P`GqhFp(Kd?np)Jh1QOQ?l$);`)2h{q{bV}4XpP?jsC)0|M`M7 z1N?UK-jv&a?MP4IU#rXGKdbBSi-a|!1xOU;LFv8)5~x*xkAxkt8RUNboIAvJjP=oe z8`D{Q2d=Nv#QnI9-%EX~QhBS38@Rea`rj4zzi=<1n3Aj=XCdIr5oQsz&-g!b$0RwfOo7qyH8wfy5vCr&gycfwbH(LQP>O`_}mOL&7 zV10ai-xPhk2QmpGpNuxZ)I4PN!?Y}Ocf4=y&|JhHBE7|{kCgqrwuhh1rSU7N$ulH;J z;Gh7<`2iUA*qD8G;sDLjMy&?7zr!0`{t%$|y01+|06SNMp;Y2QXg|yuN0#ARKwOuM zACeWAz5z9oYlRNLfIQ!t{hHfjm@TKw0Bi-O#_<8%A^|Asx&TGs)&NwkNm5>0D$Ty9 zxgcCW8Zp`IX9*~PymSK%1=z4$IRvGkPh#2Y4J`ZBeRw~WW0)B81kQEJqe}y*{78B8 z6Vxwu^Ow!{x{CfgB^e3_XM-ZlSoE)gd7~XuR*T^LE*9D1z@}hz^qT7+0kOMblUyLi z%&sdIUFGkx0*1U)`W(AA?g9<03TTn<^hj$-0??ipxz+qIA(w1uAEIpvqrsc*tYy1d{JD zlq&F7$xd)E_-^<%0!Z(e0W5Nw$lh{1MMlsle)$R}tbneKYyI8KiizLBM%HDOSVY4& z408Ymc`_REk?XF5C!z(c$Y)+~$;4u{xB&6px^v~XL_Q66uhWiCEAO7ShwYYBJDZN} z(R%!ASsBC(N{ZGP_FJrf?~z|6>9yYvDqSAzaq{Y*o7h^=P{1>*17~jlsplzX3Ct&G z%Y1nZ&Jz#S45_z9VV(B?Gw~cl#DB@y`e#|I2bDC=wXjQOpD~Hh)gVj; z{I?nyyN4ss9klPY`}u+&Vh}7Xsa4XzMA zCD8xlV@!bCBVfKR5x)1q`7Ltu)w?wAbp2vuKE*F;H6NGU#OZg&mHw`@{yrLD(cnlX zP)@lM{x>kphdHIuzagC+0&smP_HC(Bm%vv)0IJUppiTYnydG*_DSeIT9x)^iNe4|A zc6(=a7)S$D2LgI_`|d=5-q3}Svjrb(welEaD8*cWj{t(Xd>_Fp48OqB@())R7-=Pkgv5OE~leY#h!cuZTqB 
z`F?}zYkYJAUGz!niwhS3SmPjmeoVE$x#Og(^2s?*DnoW(?%#W@Ak5d_O2tY4kLg!jG|VV-Pr3T2azpctLs^&+57Le+-2o+J5B!%a-c`ikfNw* zT1r{p6RY>w+;ipxnIH$5PBP6jEa z#;I;)U|7xA)@tpP8b*@=Y$$7|Stcl#9xxUp59IDw0xZ$LSYnAjvKHh>n4U9l5SZh( zZIy&PBIVqtcyarBNr^(+%6*!W1NGxM_3AC%BYoW*wGUUK3Sv9(V*b^yi-RdDkR0Rx zsMKfVFtzXd*dR(4w7cn4HFG+^mcBC7J!l6dj-%%^hw&KYNT0;^<}(uxu(xXDW}rC=as-9_$;jJZ4uC7{+a9 zWmvoB-c%_aSi`Ddif!S(c^y&?eg)AY!UiroZ${acIuMxQO`nF0f%OIjzpb^h12>1T z`_Eqv_vvwfDJKC@O*c4YCwx|#F1&$cOsDag2&Y3jyg!z-p&Ul5oxCYQXt~n*Hrs2$ zGfZ%-`%+iBWG%qGEb2Yh*NmmGXWD=q^wAig?qe~NKR{4}0r2^Ww6+K!ZzW>C4Ap7aB@)5ot~rzZwbf-Z6w zx=pZInV}5G6F~J0DXh6&eICOI=J=F#RyiIq$TJ|X?=P4BG)0X#g3BC>K=j>;4LvKC zSxDK11)wVWMyL&s_XN)Q>WvSjGB})=fg99TUd-npdVN=V#tSWtD7T?W;xXvgV)Z0# zKA0tHU>&W%k&lj5ya;zZ{tkprfR96f|kb|7W_+b}wy-5EfaICkReQTKq%k}k}Y zQdm#RCU!zK(hfC1)DDFyz?D9LEnmhcP_ht=jmf2XlY@(c4ZCtcD3+@K+SGtZp{3pR z4z4a<05cMz40Qt9oSbx63yD0fSfklH&o4f7`OF_l8^SEA=Yzx*31QK+71&42TFg+f zF+kW3r&$O_S+x{dy}w52L7SV1TEyDN+9wiogikkZ0=;`mHkP1H0ic<@rLyme`2)e4 zInkXPk+dv86*zqN+XQomGa!O+mP{-juZpZ95`jujx&yxE552&O70L~VY-lY2pST!S zhbl0D;LIjsDuXFzhA3jox@_8zGz<+4e}Lj5B<>>%1qK4AeU`?VIs9?2jkFeW+}<#c z1Bw#R+UUw2t8ii`vx)J05hcDG+nQ|hJ*}r6Lil5KUbch~t$UB)*^05UY&=Pq^3rWM z1`^vF>)XDTEx34VVo+#!UuT_%+4UefPdsu=kKj;v-AoqVjtOd>3W}Jpk%#%6KVEocfhtcY8Pm5nRZ!vv z<&OcTW+5DO7*PFbOBRNP!+o8AK6$tNRwyXbGjUMiBJj$9 z5!}{`IPE5|CeN$>DY)4@d~vTn7`32qi)`WDtq&bES?rfb2n>PTDmRlqcLTj^C3%^X zW+20j08OuZNp2)~Vc;|F2h+)6iSdBGVOmO+c{>`Q?))kr4Z1yoAX}J^3HB_>4qcF^ z3YNV+A&Mx7n|4g4*N$=9v27fGNa)9@cr(sc$uZ>}#1&jLF=UnmdauN~UAK-jx2tvc zMN3M{8@bSh2&ze5deJ+5wnXbL>lal~=*<8j-x=eB15 z+WaU=fl`+>L#P*k9vhypk;5KWIw9xC?;kf`JzCpvLeMnYkCB=*`{sbEN}b1O=cL~b zK~fF_fO9nm24Q*pIw8t>=|0d*i02eroZ)U<+a6Z;bkrYE_!muGBI0O`)mH{BDz!azsc`=Dxi?goRFl7+T z-_8M>qFtch5jkRTn;x+nz6Xd(_lC-ms5ykO2rxCV*QEf-$xG1=-Vo!H7Gdhmty>HMjLoTYO__&5!}35UzjYcjRr zQh{>$!%E%L-Iu+HoY~#27xLQE`%{7+7Axky_ui_V=Qe2kk@_|7&Qg5VJuU`kP_$V# zrR}~sdEKNK-W4m8UVV>d$NopTsGs_$0tSgk##-qzY{_M229}{(`RNJ|?DjK1PYfTr zPBpQ~mYQ%XC(i9f(O;jhd9mN}M05FCQLav7J&$C|)uvZ6)W?!;D@8JbGQ9jUCkVBR zycB&R9&1DD+wY}1gM`>1uPDC~F^YMr4Bb0X{WNBlM86#ILA$v3i{54JtfrS{@`T@s zWCURr`&GyAtLTzFCvh+dS2r zBZmZa3Qs#s%bQh`+D@b?2qep&smyVC-Pd40GU9&Vd}dS869!3NRksA$KnUUo<2jz$ z`@+M)WgoP2dsTrpi(2%1PqMZom)*ydFQs8u))ii40GP>%lSX59R0(chG7hg65)hqS zMcd=C(&MD0ggM^{x%P$rQV5mb$JfIbvZ4(1DJ%f+v(VYtY62mp1RzT?DB5j}5Z3K* z7;}q#tIYk#w%CP+;Ht43laL1lR=|Crd(7$l3 z;vtYiyea*;pse!lWWrBnT}ZZoWM9FWs#w3wOjmq(I_JVa(fz-WK7<{)XJSCo8Fol_ znKSF4VkyJ7JiQQihpXNEoSNySX_TUG%Zia%kgfYY)PQW2I>Ui~UmWl&Z^>iQDUug= zy1Cw{^0p!>RkSCUdms_Pvsg8KAlRN9Hg_X<$#lbOOG3oG`MIR_+W1@glBg4}ZIuBI zlCKMSdzKr-hH|S1tBXQ(S3FbHbeJy}ZVC((%9$74EREb>>vVWLb%rgB>_V!hU9N~^ z61^!teB8@rVm!och|y%qwTH6El%MX-UV^>sSu?%-rk$SR#OQSSb?n!BmuPc2%h zupwk8@&WX6z{Z{joZ;T`4B7Yvs7Jgh-#dfe35GmFbn}^1Di&9~&aI9}SLY6$>0b5DV@2j6) zNcuc04N%qLfZE8_&xg%mFsK< z<&a7;Caw_gdFVu=k*qQ#v}oY=hvhdbk8U?^sflu%E{do=`>7=EKAX#D_*u@AG;S|n z@7NtHjBF$CYfYxxv*P8dpM1}+en`EipOE{ivBBd!9z>(`Oi22vqKCIs$Hne5k74y| z6D&{j>6WcE=(1ld@)5mrq*}9jU@~n=`NCU=d=UIO`{kYI0fKYQ237UnJ$I?X@_vYy z6$^S;G(Ir!F%fxc_bFanUXXF!>|-%2HUFVjlVydPJe#qcw3Kp6LCZTrhTt1dWVBe- z%PVacdWE;PH&0Gc(O&W(z1FBx@;4>S49GD-ppKsg@gi}c1zk~f7`W~bRxO}HAGjlz zk}HBz)zg4_06krV$;UH9GEld{Lhax(bwEaH)jA~!fEfuO;ED(4kCo*n9H^KII{e1v z4sKk^?PNhq(own=Z9kjh*jMzovw(gOC>m{_P}kP0M>0v97K{r$#lCM$AczLGG=uLz z2G=K`f>hW7L+OA#pZg)aNr?37WanzlCc{N;UB!YNlpZwW?e`C>U*Cw8rCZ_33+GBz z-gwDMn5(?ppBDkLP^WCoroWW{F*LO!Xl^D(t1UX+Ufu%|>e(iaaoG-65YoXcYu&pso3((*d2_cK%0 zhm}a(WX(H_818yIh3!`hCxL-unJw(!WS_9NI81M>a6KB72n;(SQ@P$g{1F?yI(9?A z|H9|!FhA&p@QXaaO@-$p%7_r5_KJkUsXm-MjbJk`huF~!CS%L4_8CwV_L_4>lBBk& 
z)n@q2*HIMc8+^XF-nd;8(zA(fQ_d}WkR;b41eGs|ovl_&2z(^2a(LB!5b1SxcuIQNFK{8}lY&>MZg+b;)X{ZMLnw&SRPGAJwxi@>4 z%OP?8g8MXth5=mkAxQugtAC2)e$u#~rM5ssnf@j03Cp)nk1D zw3JuC7Nh(HYbr}AkeUPHm2me7X%YY^%bnyUoH^`5Rf0%E2@PS?bp2d6U;1lFV`pir zV0AK!VWZP=SgCSLP*+%_SYhN(ccozYGbbuEPj|FBtnpkC$RRl%ZkLv)1p86`{v^Gx z4GL|kSPmZ;6YQ9Dia#L`ou;r&IaDVJ1`wN51oAcVnOt~oEM#HJcNnL2lv~hpKV~KB z4w|r?hp71$XF;zyF+t|B*n1w=i)JkUZ!`E5HB--_g<^ou=TXsNg?QF zw=Wnh*#sUn!>G;;($yd=FG+l_I$NxUKKj$xLSc}ZX~um`f*3@mCc zLX`0>Aqk*u+yi3xl?*Y9`bR}t);7^5$uSXD>;|v&O}VTa-K!?wxEQTILfwTDmGs&G z*|D1?hOUHQ1(hLQJ-75dR05EBuDRlqr0vO=uqUBO@CO00EZ3}qbYS-Iy9s6G;W5+5 z0GDf+8_gjNXG;TIKJEN<7UK(rIypPrs1;WLx<)cAS#FZjOXQA@mRkp6II8&;b3qV% zE?m3nNf#R0I;A?j6Htx7_^f6bLnS`ewb2Hp65`xxYQGIoRQc_aISC&E(OBOuvHFp` zl}(QrF+(7&HUY!pfr(3gUrA+tH9DkSLdp|!3~1Yd@D0NXdh&l2C4y@vl%iz3?!F!= zB?5p;U3IdZx^4{UUw2wW)JuQU(bw5a4VGvZag9R9Th=@{;Zd zHs9fWhG2e~-h5WiDP$wVZ@_D$Pptb^aPy^FPo-0jq1OSt3s=TB;Gmm5_u277H4nLA zLK)ui!)?LJ`vZn=pX`R8y<@Pf*UY%~d^T2RwSi|rf68NR}#Q>t=6o z{34$v{WOwRHiOIrXQodOsIaWmV0oynYZKx|msb*;wup|A~Nl$s~Rj z!Cd*QV)=vBWf7Bx`w8GG(Jw8Wq6Y~>abgv7D&gNhBqsQ=aGSgUD6deEzQWf=se!EXNpy1+?2NZc*>j6nvPdRUGyHB^;B|L>YmZp2A|zg zlC926zpkET^#_LD>%nDB8EnY-a5}d9@1IiDmz$2GeVbknTbWOlO>*C=aRrx# z->~JQ8_rkXtyuVqN9dOtYZY`=J|kO|z~vRNai|okzM$zb|e+-RHwHmYU~ z76hG;P0-RM{_wi|2V|h|F^U!Kl7xq<07mOffh?L4b_wT_gD%xw1u<(HRHD!>{E2GH zm8>Q*F*XFkmIj1)B&HBy(xMe5 z6gi<#ziCrzNYQQJuPL=WIS6cB6DLvx2}J1jQkhc5(_lX-%!pfzg#6(VEV%=h4Yz*lypNhai6f@c0o?3mj+Lon2^m^CyNJpyo4v`JNgR#&z|o zeVTv*h0ZjcqUjYyFhBvLjzMdE0|QwxjN@dp%dkK#(dV+}Gag`*>pT9^0b<+eO9U8P zVjmD&FO=J4Nc*j0yX2hV&j8bvB zI0_AO^W50-I1gqo9b1#89z7FinkH}FXsP`%hpp08dPhAoB8&&MT13ChP3t- z?s-Nk-o&IxI^nJ$RlSpZoL0MmYQft}s%8T>h(!)@$Rh|eE{Hnc)Ydjm^%2j)?Z!Zo4mg@ z==5CpZm|2C7NI9VS|M}$-NVjw?_(-A5@t~+F4Y9Q^AhJWiX7rgTj#wyU+~&l*WaB! z15>BfwZGeDinNZAVE>eLJgfNebbI>h8`0Odx<&{RzFxLO(oB|^2&->9Xa?-|*6-5W zZw(4gB%3OS>AW&3o}FYmJoNllv3+}A)bIugb1p6^ljPS9;6<>X?jv8`u24AY+g{#o zx~xi3OfabaK%2m(ynv6P+@;B%tRA^Zms!8Qx$!{n%b$UwLn82g` z9d0$@`FF{B1-*moC5QEQO0 z{P_J5!JF{R>f^dR@jDIgUj6Wod$H)hV(!F`Gh++NjIW+2=tg>5JRAA;%5|f~sihxQ z-}>2;5L^CcIsxSC{#VHuyq_o;cM|$`RU{Zkp7+15*B!4dP&CJ+f|Ul_4FaS7?#ET7c9D|k1E}~dmoc5g%|&(& zQ=ysg7DXN$h&xJO;s@=xjs~NP$<@Pf%NrOA5>P{#kgd3dN&>%B;joZX6oD-z00y4{ zzu3wk6Yy#q3VxpUL)sguO@K_jE-+1*9R!-cF1j>ZnRsLmy&hI!+qtZ2(?EzytQZ0_ zMfdH1>cG4I99TaeYR9Gu1F=$Ks~V%98n+txfj!A>kW#L>6CC_jHiQxY5zxDR_n_-% zEi$;OuHF~FZHbF>S=YT@t;{;6bVa-+ago=)1`AlyOMPImA#;4BFF{!)y5 zIm9CYh&5c~`hm*ml!x#B-6kAfT1%C?L7R9&U=AVt;#M?hV4c3eeBMi&|3$s~GHt|E0eJfZz{fwb#IR@Z=0XB^ zqdkIt)u9dNv6Hrw!}LsJ7mI`Nm1^YvML47rf$$_UY)aY7D>1LtnYzp|@X*dr$mf0B9iShk4laUKHb`-VbdWa zy%KnY>l4;9N;=&N>`Vw%WS)xdhy7}cHACbWZv1Y>Xl(D z=uIEiYuvVZazdIxAN~YY)t>ftim>Th^hG|E_coA(hE6Sp;e z=VOGoi@zPoD3j7(_U!lxaTVY?uBFgynZXy}hNU6697yVD69tV-BbrK%E22f~_es@+ zrsTtt32Ar=$6AEyL^aZ76i6A*NK$p@qMVn%m3vU24hrp4wZ9&OvnwU#DDK@6RDV|T zJh8c3uX24D+jKnmzFwY3* z`??<52SSJ}ZdJJkmXgurR@$o--M*Z6k~hig$su}oGW+=oWVv(jgKZ$eefQOGYziTM zr4iq|lN4e!tDkzwm;0;t>Fz$6HaD*H6k=y`x$|P%`1A8xgt&Uz@Nx8syS`*e3&DyK z`$^>uU3RGV`NU|w&x&t}vi`XuF1p)#{^C%>6g=iusbOBc!%W3MC6*jhkQWMPQpjAw zH$T(9VGDU>>AGVI6SERwbADQJhIJHKZ>8haBK{y%Zl9a!n#Y=tyi3#cw|x(! 
zgQt29P8ptIWG$bLV{-ahADK?2#RO1kqY2GKB3*ECIDpLY3+0(-h4FI*e_h~ijj8}x zZ9gctPH!Gk2)3MBgRGr;($VIei=#+&T8T%6u9)B_On%1zzjxRwNU&IP%f&{wOX=Dp z>K^a-`f8pH$ocKBL6;e}}8Fkx)s>4N5 z3V{TP+g}&Bu&-0s%~??}iW(zpuY9JM?@6Yq^YZKj!lzf_KqjOa3+a-a&Jce1L*tZ; zERtE{EJqN{7{C&$wn6k7oWMs%uXlwEqeAflcq2(p=E$3IVx6I^R8NRlfb66aj29xq zXQ3#;jjq+>%SZ_;?; z8!H)}b(_XKy_}6OiM$RKofZ&w?1QI6puOb%yVHkqesYfz^-mS(d*8T?qe0fY@Y$kL zru;}eChElkzg5!Zi0fT3Fy1dUy(^bm0oAP4DF|6y_PCg41NUxixHuf~7-dB~PCmKL z|MMz@wFEdLA@t){JFd5*kXbJz++I<4SdT${ZZeTWT{zz%N&WWEL0a&|;+M*_f9pa2 z(#OM2p+N;!_PtG~!k&vNBukARlh2$Dc)gwsMfJ7=*5XxvEsjqGbSijl5WOAo;9$>f zVNLq$f^PCD%Uw> z=l!EoubYlIya#)GedIYiBF}gX%q|W2X#2S?`$Ua&+}6a#CmT-aa~jBfmN6mWf{yku z(V+EM|7zKgY*qRCew+8C()d`!Grk0`DS^SvF`I1I8ut2-Z_nIAsV5ncy&J5o)702| zP>mYj!FNRnm>r46R@BLpfHUHQyhyrS?cP$R;gHINe2sz6S1>HECjHgR>Fbd z5HB8|_e#CCZ*`dflS%Mj3e9mk&1i7UuWj7#Av>*?JZOlnS|+N51!jGU<(PQA+LDwt z)skgOD{{6?KBV9~yh%K3@-a(y@6po5@g$ty&Wr4(sIJ%6jr+Ky8cmQj>Ru_zfw^>> zZdLTIE0*+JJ{-x~ERlFGetMV0J&Ap(~drfWUUI^BhXmAZS? z^QY{?A0<0$vMLed<*OjD=0cWF!I}h>T9}PebwZq*1-ToH-fuK}qNMyxwfKk3K*o$1 zq?H(Y?S-f-m%Umtbd+woH}q#_*CXRi_aF$Vn2^uVo^7+;M&#-7RI?iGTh09512vxy zlBrKd^Xw)Z~q55hp9uqpezq?e^83Bm4~i#}Q%;i?rl z)|%@~2aQ{au2r_f8AoqX2X;UR;x_c_&bi2XSq8I&1L zfRv~OOrZf8F$9~}c#`3Y0?<%-WHU#9vEQJae}cjMy{ZnGBuZhUQZsKOkiYPzrPZ5Z zCE4J!1e8M5l~^~xMgzs$ks+=Or0CT*TSCh;=wZabZ$6mri{}&YVPCyFoO~$`R_7Y& z(U2%d&TlODqVOWmIn!Aj@n-CG=7-Efd9olIUhc9TgA%AwL)$lg&b}qqoIX@|RnT_> z&PrHU0)I0MB7r78%}x04<0-l4xO%bmM84%ePAVH3e86A_)Hr z4dm4Sq9by4GaBe6W}a>zHk<=WF3b3R!~R>%e<@(UgdjW#4L3;Ewig5)??anb@6sH& z@80gdi=)Np9zZ1)5}7V$;)h3K?JtPJ8Vy$Ht_zdPEVGHfl<*a*5iH_7D(`pWss2?J z@j2d)K@tq|-fru6Zhc%X*Uf*V%HB>}z}YfcoTDgi!1HZLNKzY!BGMan`%|SG+~Sj~ z9E3sw_;342EBie>0drOcC|%Knb?)Fq5nUPfG2$`eTQ7DD`U&R5PuERHU@?e$@6#6^ z&B*fO`E#{g0vR=EV{{5S_cTN^bwrD!m2ukt!RxN)R~ufBU=0HSok|pHP;L@6$eYut zHAwkpfHP)ahM{!K5UnPgw03*3>&s{yUH~+AqQX~4_!?*M$_+pOy_q8p5^St@Bfg(K zUd?QX4uAer`}GeY?|prn*JAfI6PVQN!tD$*b!=Isa$6Ys+~!Rm0dGlI)(<`k9*@%y zvU7tad5M+d=i+y>FX3y43`UMo;kc?IjRG9``N&N{Wx=>{$L>>_sy7!_>T~42%=PY_ zCFbuhWltmSGo6x#5WEv?ahXktlX*|@3r8FB? 
zxCW37aPrHZX{?>O4UcC(Ix2bn!m?*(a7W``wp%m-$rx>{sHP6kn=EWugm~g*Q`NPk_`ifKxdTDF03cR| z8vw$gSQCce@K$JJn8k0{NTz_!$I0W8!jNxUBEI3QSp!GZ1CUwZF$%W?0%{V|X5U5- zeqtCAu%Ubja;gsig^2}s8-kPgrws%gd&fIVZ0V9?=eESHIOCl_M27415O!cowsr+p z1xSha>&$+Dgvo1?4GoFx!t%S>mW2CR7CBuJ5lUT>Qmu_q%t446Ts9_YKy=CcrWD55 za^8ktbU$_vpZgU&nm|)A7qIA9I(Qk3P@atw+8LV3{d-M^`55CCzvEVuD>McEv7)Iq zmic0|v?Jw`#|Bd%6t9#gv|I*LAN{Dm^5ecVGCV+m95UAooLi$*_&&3Ov82%6mr2-1 zL|WpodZ66tt?`BjQ0WsB2|kqf+h^{2{IPoO(cxyPIsf@#F@A#LK$ZslQ^DhOXM|82fnVK2F^~- z)UxVZVRSQ>Q{#|jQq`xa{j;HHrlC_nMN_`3h4(AGVE4mPrpYzIUJ=R(5EeS%SywEl zm^=-1i6ji5GQ(%MTf>{|^X{C8cgEOJw~ZDBurE5Z&RAFcrMCh_g3lj^&>ef$;O(r= z_!u`lCZWBip4ryq!I{x0uyw6DEs$B_f z1}GtOi*DuhTX_-P+^D^9#^K*anU(?b6ht=?=@^FHoZ}2F8(#GtH#*t8P5^Ow6Kla+ z4FcIBQhrHz^5?P>tlp37dGZN`#i*LsU9k$d?zQ#3=97G%fAz(bgQqXiF9h0ojD=z*PmP+t1V&r!y5rRBSznJGxNad@ohz*$P5MP;Z1O|zYbV}BtG5_bAOrFy&b2gHYK>fg$>Dk_suc*6h9Tya3Q*c5yxJs~w969@B zIz20*+5HyYdjj!eQdRTTkS$CFsDjdruY8`lc=0(_9M!luywVv^l6Dqqxta8{*K)tL z%1-we%Q|o1Yr#S(^t3tnWK4tW4Y6yT zM)YTUA;HPb?yPeoeHE6tPevo1^wgqW17llt;57zDQ{NH&aBBXb37Ix1@(Fdh_yv}(AfcSWh4$Vt1}dR| zq*qD2Y7_gD75tQOl2KN;tl?R>LOe@!vM08X*=<#j^FJJQ0+;+w``s(|mt205LqQ@B}!_ooGJA@nJbys7zH z%hEhOGq3Y3DkV!kO#5c3^H6T{O%7pGAC#O(*Np{eu0aGVLdI!Wr@`RS%D6>?-eR}) z0r@2~;SVfX!v&?F?{(kTb?d>X9wn}`ZcpaGTLvnuEnp|NHu!yfqfCVkf16Rx7uCYb z7eEy`fKs?Ym?5hu!M+CO7DrN?qU7ogY zh`FVPg6vr$_lBW>HIw-5&DO$r^P%T?=~nC~y%hD-GcZUpMQ)^=Gi?JjH_}}CXB)HQ zL0h{3ybhD|=I)O%z`DaLcsqTxIj~N2LU#;cIxGfB29CRS6k!Jq?pr0wCto*-S;(RK zA^1&4dD(m9*s8Ysit8zP`;U*EEtlO2lf85VR_s8+St-IY)MDd5|e; zxml2^JNc_W@T>_221gC9By>dWai8?zLJq9KW$(DNII8Lw=JN4CwcGyC>99o8CQzPH%{ zCI?jhqMP#r{||d_9Tw%*z6~o4ij;tYgh~h_AV?@kiAaMILx+^q&>=0TAfTYY5Rwih zF?83UfCx%=BhuY9#Jfh^`+2_o>h|{@-#=gYj|Zc2&$`#TdR^CfCcv>9Og8(Mn`6Aq zN{Brgw}zgbnmc7Je%nu(IjGT}q0ifuo1F*>V0SO-suGcTpi%0hJ_D0)Khkzj#FadG zstJv#?%{qdY_M`IPMIKVXt<(q8*Q~`(GXyO3;UpHoUnI$zW@64#!$;*g$r&AuC3sW zw>o2X7tWB@;%$>`O-|QJd8wqQ;F*8DKfFCBEx8%%?k&Ctn6uZ!*P(lB*koBzGRu)X zrmwuJp2Y5X4!cXEB|y*Wjc2t)OWR>R9609jAIL$19eIGFe>HT|f#44Ga2%C($B?bp zz0~*VhRJL$0rV2W`^&){*7rn1!eeV>Z!`)ggrc_>FKyzkl=^N@LvoByL5EA_cgeFZ z2=2$oa_~tD(I}jw-EimDxNPB1=7u+(OAeyoM=vPX7X*Qe8>UirxX5Viv2^oDKi6Bf znfa0$Lu{W%%1Xz}-5p?;PrW#{%ifsT@ekhTDRCiRy(5oo$iOT_b`s!Z>HrKw2Cl5H z$qD31qpnD?Qy(|2iuDoTTmq>Fvz3)*&`-qZQVb?q!1OfG@WHZK4Lws%UE3D)m^ioV z+G`!hPcZu=E6=doBAR=icCzb&RK*L%m(O!YP zU7uR*z9WVPhKSAol^ekWmh{)fDGm*-L(^NHF#e# z*VmG!tt4&F)LU?VvaI)Y6BZ|x_x`Sbms>%3CFli2pM$0f)WJy!6u3KDiVX?=^9=PrHU^I{UYZ<-^s}s2uSMhmwojSR>fuYe7cWtov z{EXr{=i?omOX05&y8fKyN*StGxi;sjX_f5;zv+8kSC2T#Rpv33<~l&*C}A6EvptS` zK$i8?h#2W>7i z*QI$2DnmqdI?W;$NN<4>i}o*DnspV$av{~=PQ0JMNrNyXkDT6@^7;AKA4q*%h({(hv=8v_%}#kWOr>cAMisFpEuVB z5SZlhFxo@m8Y)cC$+!m~1~JBIW6^!&{Qu?)#Wq+;K@<6_+p>(wc)r)S+I&1=sG=8+ zqSELmxegjPy}!E5b`|-Pv)u~MP-BYXWLuLgWm9- zXbfXkGrzXCCy(neKEW0({IEJ%F|LW9@-$5$MV0}I?53DbXMy%Z@h0Ns$7YGmX>@Sx zPo5lok8Nx94GcHH(16n%zWyWv1odYqtIAz^(k&}Cu8iH9Zs^$mC^eqvYug2 zET%=Z+*b}u^ANH29a#j1%%-vSZIV)}waByUfbca8v>}3qzw4>Ny>iPA(bX?x*?a)7 zw+>ea?DYK>acQaU3v4$7DOanBx3`dkB`G#X7qRYB$NJQ4*E7isHRRdvZD_hGmf?X= zy3GhRd>44lyS4^SVr$k%pdR72cCOWNJDaYkBDClQ8UL>RyT{s%)SaoPloJ?llV|Pc zTZWvb-7WjvZXy*x7p2s@sGU@Cv`hE3_IOv)EeO6|t?6uSM^en?%;;EM!MO8ty>8Y!kHdPJwl zYnx1%8PigQT9OG;0)qow4jpX-7Zvu--?*S~PUun}^-4R-$HPS-mkC$+;b)q=Lr)!D zY@j)&!pXg_mokqM-`dY@oZmc@%fWfNHJ$?I8zPp~58pM2ee-h?z5^jWs~E?x;>_;P ze&>#0*3$+JLX!7HO9c)kdT(fTEXy+<610Zq)SLqbrc$;c{G$ANGZ5KNKSVa_Z;{2t z@!p$dO;^i|&?!G_JUy*Y1n?z!z`)HW$;+`L=@rh;u;T`aQJj53(8RXd{8H*alaqgg z9F5sAlE(YO!_pLGO5zW?1j>%!ptVW%o8RU{w|c&X&xz{xO3=O*eci_A@Y1C|Fou=O z5LOn|Rga?#Z-P%OzTvxt!%FHSud;L@%Aj^%^5JAh=w)=~ySX!^vW{~-Dr5V;0hLqD 
zHCRKCv8wtVtHpm9t9@CK{z@MA3;8amRnPn>89vn_ZPMi|S96UdguALc@~-ee>^!pQ z*m155zprgT1+VoYuAD!xerK7WT$1K>a<}W-&4WWZMoe@%n`la- zQpt>J9h2I3a*|wdd(8Ya5vIR>loj~hb6FfYmLg53Eu~i7Tqli!9LY%x51h`5L09@| z-rlks9mWiL7sZq;>WvzR8;``da=s$tUiLfgssA+Y8KObNcj?VlNNUZJ1jyMr&SD;& zSYesJ;DQ#zP7WEKpedJsIMit{ew;m6VKJf6QFLWU!jPx$;o@4Lg=mW_DG2k6dZfQT z#TYw)WovuwPB+uYx1~Ksy0nV8I^SawfZe}WXacpW{q|KIO&Q|q9Fb%@ zd3k>5!!tvT?(wkpvvzary^UkocVMjob=2$-@ZICWZPVlURfo2zvW_dr;85@%T#9C;trSH zZ0(Ee@;@IPj@7R7nJDGr2&slcm7>Q;KFo4~#o(*fahEF=o9KtK`lRX$j2c#~CeG=B zdi_R%&3)U6@6mNtgM~q_Vqu8r`TZ8u?GaP!3pho)AWQqv*(kMYH~zWn_T;8GqY!7G z8vBYGMNEL-UlR+@4kWy782h`N@-7rEVKM9}ua^8bn*Ki_TIEks00-+xprIIG<&LM7 zya+cUu44_2F|eU`XJ)kW24++Fp|gGG9X{jSMT-#hy7i95S|@y}EE(bd>gJ?|=W=c3 z=kZFtn-+Ci3I867Fv=&7rdRC%e4Co{#ltD+J9E%RvrQM`+o~!D3AIQ2X6fx4Y@S>( zp z&y{N z5>rvFmD;d}Cl>IfWM%T5h)*BN3|dKkFX9H@_jOtp5YNt$K>~T$KhMR}oYBtGbG*E^aCI^#d${plYclJajSx zqLW<{s|P8MCg5e(_;o*x?c1Mk^)1Z%5FT`qnm4t*sH|R3SrovWme87+CctkKlad!J zNn`D-77bnJjbr)}{ya13W}evzb9>z@$Ruw)swt$zvFR(`lWyD29Vc`d7FDK2l4BDFO41N`nrY&sQg(TmiqC(RFDn)X|k;&g;ZPFW>V|)g#UCPA(4zRcdoH z{J@(#LcLXa3)n5;3y3u;_yMM66s{S{b1~>7Y0HXL{JtW8};nh@Y;pa;P`X9@68<3o=+VnfmZEA`-R6nd=2Wpq3JzL!|Wj52QR+p zdEo`%-i2e8)Y&(2SH(faR1uNBj#7Wh8d+Yin%ul@5x=WRa&H4JI#=UhyGxc-Q%_yD) z4}L3d8SAjLjbQ3ZzY9j6CDH#T;cVhBdEzk03s zuk@nZeu%iwqio|TI%$Xz#Gt2!QO2S-_Y86Odv)PYZ94L56B^35T~jm*Qv~>CN1HVn z;&O@sZp$z;-#Ti&WG+dy8YtKH3PesKUe$H6)8ZYLaG8{kz4&`lIIY_})W3d$)yMRr z@!3sq;hO!5-!0mAItholN;RusJ8^2LRt80Cw*(_f#8j375nLdRY5os6*`{Qd(-iK- zjM7<6u_FyEM23aB$Lf+-ip>_mmREurBw_Ol7+82RO3@i8X3JE7-B*=Ev@vgygSJn| zJB$bOcc_fX{?UZB#~%(o>$K-iP>M+$1Kph&3gd_KRrdMa3?7&0cVotkdNWtwL$|!) zv~>F6LO9Mmt`iB2fo3}7cY=OT(*Aej@Bj5o77Nr(Y@V-qlXQ#SxeIB>Eh`6DD%}Hy zq~7Iv3;daG^^lMK-s|7#Hst@1Za@7sO|;(n^UVASkXXcmx@@OZ_D?`7+3O!6B>z8| zScB;N)!US7A;Zwg`#TI6VAbzG8LIxjF#I`N{=XguBasl&|Jnuc{|%Sn|DB9KDINbX z44xC>@25(1Pqy-lXS%}4EPA47CUO<2wr48xF`vB|Vuv{`Pt=U=;DKxA&y?}!TURoK z&0q3;=ou~hzX??IM__nufTjDJ+aM+IpFl)m=nt|`ZoNUi8`oiC4UeniU8KcnZ8F|} zy~TgwhzUqaOBSWn+24L2j%*$Se%B?p!Iq%E1001xKSW8~3_Q+ybqMEj-+sP3jphF9 z^WD*KWU~KTKf>&-hoyGe9q(bMi0=*dBstXMZ1!#Qyi^LVk4q3Q&#ZJ`wNv8^OuM9~p?$ z-+}P_?*aDv5v${FI;jn{XeOR9h?M>?S8aW=nVqsWLE?V8I8m(AG-qY>_lF3c=U5>4 zi|XN7kbe<~Zh??Y+dluDOKJZzmwHYJ{#*{_s|2Ly(Gb3Cr$oSwn+G&3-SrTTH3AK= zjKE@Vt_Ey$QJHQFc-(pb6!??ub01qL{`kc6x32*n%M(J7vY#Xfw)_5z{o?793_i;a z>2D5CviN`AMcIQPI18Mmk+4%oPP*cb9 zTx|851+K9-)4-R)LaWveor~RcgnWCVF!2f)F#kz!6bX4|I>>z}rCJ3x{LD;fSk1fK zDxpH2!Gv&r%E(f1^5-}mb_n+BDyE_eD`$EDA@*hphzkXQgYmUW5N+-sDA*?Cp8Lz4 zvpdse7VUT{54@H{o+vt0f^A7Z2Ka6+!?ceptri-0>dC$N3z5eI&vff=aYYt3_kI^P z*)NcVjj1Azoc*-}e7$b)4Vp|NNtD+uSl8zAsKBkv3JQGDEi$8(zk@?sC~QGQ z1g11&6yFfzNTJKymv3y>`bxgrt`D35`_L)Irb`kAv_74VHsoCn{KG|G$fijXBZgM0bQs~zxOSIMV~_^jaM=al_I zHp;-)aO{be-LP*Ujd=s={ZQ@$vG?EPzutg_2R%JiC6sv;?+p)Dx)-arpUncRcF zx0d3w6z9XWY2aFq4+5^W5i7i#o(}3cUe1hfeuFB2$Om{Voii-}8?x=u}fc2OOJD z`gZ>-hpiOc`#VL9>WliS8{o~^o0+>NRj&ekQWKN_i*uaxV<)lW)S@5YNap3VIqc5( zKwG{%ePx5Uk+d7?ajC!O$YzXPpcF}I$c6{vE-f5m^DlnoEz|?#tA+Dd+zCM3y{9J^ z`o3p3o&M+;OlW+sd&hR3%kH~{iu!Df>13*xW@cnwtES$Ac+QCPqfJFl!TVYo(#<_8 zI&OHVo)%NGi?WB_1EcgsEX=H zOLYAcuV`R4b4e!dFDM1`*eWJMPM~hn+sx3EBjKG_l(@>dgO{0rU?A;;A)EYX%o3$N$?)~ zCk>$~(~5}RMH~C_3D%EiCQbRM6h=&;I$5_$S6r_L_Q6IW)85fuRxUT?yV<1%lxB&F zZUqSH$RnqlD+E$c&FZYRwb0~&ow>m|{^JQ|1jlaq=nhz|i59&z&6NiC&pn%``3L_u8ypq{n&r!d|zZ6(FZv_>K6%mr$>WRxHInf`SjKC5im6` zz5T~sZyT_6gYjZh)+3>|Z;kpJA!ystaAi!o!bmBZ9*gJuz87AKT>y^S4&cMBqf&Kq z5`t;e>7w5WFo@;*5*Lwg6{5eJJ&*%gZ4di{2b@*Zd zyI|#U?oL(9Yb{``K{+gc^d*u$^;xroB1y_;EZ#RxmueqM+jl?`sC+ETJzVRsOW3B4 z#e4dV7|6mvZq67Y$a`p#*@_$^N&WeCHV?DL50Xhv=m^Q;sE4VUaCQc^H}=M#&M 
zr4rhZzzHz#_yGgL&|k?<8%&-^sO+(5iHe8o)~iI(smi;fn;*2Ztxhb&*&Vzg8~oJ0 zYbnLybC5Bg;+pKUMO`6vk;y_>3y-1wY+o4piF=hcBW30#u4Q!?jcdu>jjsEVEu2{u zm+Me{RJ#8Cuj5>Kjh%@tBRjIZgB>4q6!$drE3+xq-jW3Aey<3g^mePwRG>a4y7=OQPf}m z+59;&KOSI(GuWPB+t5f+!-@L#WN`{Hs4ugyKezf~-ba+eE`Ck5*=_6lA;ELCsytfLcl8Lz&t`pfmZlEte`}h;wBRhc&!+tG%1zdLXWd#;yeoW=7XKnooNvi1VY@ zJV>Z#>&wLJKCsMygpNXW*zN5Cd95!+#$;V1=V%r14EYG}S~vm49l`-zpx)9~pwP1< zd?VKtr87vlI}1GMm8b$vyVAe3#L`{{coPGYSs;XR01}vlJ&WK92?SLWWL~no0V9Q_u#l6)+8AP$X&<#@E7~VIWVCAvs+K`aPP`v|0?!=u-xBP(+YE8U7L=ZP$aW_wm!W5uub z;oHU8h7lg4S<-`TO`D^}1_w7reR1*P*Zp5P=oB}sCJLtGeXorMlOx3sUmFj+jyJmu zz647cJS^rdT@Yt|DZaca(R|U-6LV#P9Szn7^`3Ogv394254TFBdanelwDDkXVvsuw zUhO%P0rg4m%7N2^<1eqn#w1DbozTKHYhbq$_B#f@yms(qssE$3I5wBr)N*S-`qk`* zGU+T(DyKqjHU^LqX5sDOla8xjA$CocMn?6+FUo+B#n9EjvZEE-fIG{!gXn>+eblSv zrBJz|*G^H9^-HV}2zAvwtV^D?>9veeeJdphp3$(~(i0A+9N)A}TD8Ss`cx^aTgBuW z-Q}>`xFMkf9Qcn7>;HJQjA{N9?tgHxck1QbU`zj1lre{@7z?r*##p1N z93q;uJ$C*>uhf7dR%QFt^;h$~r%8OZB7vo&z8mU89BZUG5|;RMpvIbz{t4XUx}8MT zI=Jo;&dFQ-z%q^2_JBQ_c@M!Ei}?gAwX;IV%)8)03F_v`wz|kd`<#p=_8eW;hm_zN zu-nnjAf2!<>e0lBY}UWZ*Ctl~`C_CG^vlWCp4P+lE)Gyp#pX*`ZsIVPfK6h=A*?%k z|KUzapZZgtKD7tL*_3Ii309anQv6$5JWzWLWC6#<%5#j4c7vTT=E-w#@vN3y~L@&Tf(m zK!=E`AJU6NppLmLdOY{djF5z= zzkrx;g^$7{0coDmM4U`n9;Uzqi!NBEaJ!~ZWhj-Pr`q)?d>C-LuVF%)s?sRDFOqts z?f}aDd#PgXjTv!mJa5L#SGufD5yRFLRnh{t9ep&j6AXx{zJ3%iIijmf;^oYgfhq7U z@xHElpE<@(IoU3PiBMlgsSD_S7o1!i^ZbS7C(g#Sq4Xia=C)4CzY)<{7_2k=L{7R-{&N~q7t=&yv&>9czxIBn$>f7L; zf7AoaA-`H<(bU+xEYKV5(J;jt0c@jdgL~XoGKpiWK6H>t`aZhy5e%1WneAQY*BH}_ z3JaGi1?>i7J#3QzRPve7^|P`;UX6W9nC&1>&E>grl-MU-BMDEcBWS?>*A91)Dujik0zNfi(zPiwpWpL{jl+w(Q`>EYPbl zCqbWfiFxlPN3r14iFzjON!OAn>6xC09-Dji-sUh|7(2tmt_2vLYLxj$UrJv#U;5DH=vjyPx7t^un=y6EzsoLY`JS+q^w3VQ+c z%Zr^}V#lDr8{}DD=HM`&C~#R7+sHbug;ZeTByK{dc*+cxI13$_%Q8K`maiK$zh4_5Bw+y@ciWUxn)GKzH zhe(@(r8>S~=zWdpC0;4HK`er>Xt15zXtBh65pAWUJU&XOKcl$RlirN z>1s916%38&!9guOXXWGkSQ+gGVKk49BK{d?US)+{xg!{cgOM=(1&B!w_b?{k^RYssmI z@^UXcHVg_;mF~o#0W3DNe0ed5(^AE^T*BN#2KJ11GhHz{GM5c6(Usen*iSdrX|p3a z_o#s06d-m(Jg)-ZjWMehIiuBw!Iy%V9x9?g==}0j(b8drjoA|rxDF(g#l|k{7Osj0 zUX{yrhI)tWb8Ii(+S&={GYHmXGm>g(=@hnIGjzWfYmZ8GA{kYTQJZtbnhs4i{zWOM@TaAjA{P^>^< zFoSa?yWt${8B`<%wW80&;r@PlQ34Ss)Sz9 z^G#By_&D2-blT88E&tYMdO|h*biU5f3UPGtuEAo|I&fB5ti%TRm99+SmS4JqWkGZxiAp*w+9>{ma8^S} zw(-5Nh19-l?0TB(;V*%ji(XmxgGkPNo;F!VE7DMb(tZ&X94-9el~ihsUAxe}*A=hY z9kHC{^4rLJ#;?VM$gmd9C?BcK=~OtcYA6gTcO8DM>~DRv)&tn@8$9F(G3)WJusqPs z?e(t4etVQC`y4}@kW3(#T(;oe-oKLcEfnNb5D%4(;b5mYMvfvzu`o$VU-vSMaqQza zm$9x#L{vO$*3EhPWS)dCSrsTv#hC+mH`wwD!c1AkpH zAqB=bDYMj^iVrVM=aM_a1Bh}&@CzJ0rCz;1tl0!mh@F0VE)s~9AMJONOwJ0S=i)Ys zchWSyqA7eY9DKHaYVNJqNW=OL_mWzx@jZn!nTEU8zN6%mf}BHH%p|k0`m?n}+Ye*m z2r$9u&_W7}%SlnpPF1S!`g{p9$yB1QbES8Yp&MkVn#0LT7{uBROYvK%Aj57a+{ziK zG7$TwN8Gq&5;&+D%&Kf`(U@yXktM&bHZs)1pt!C1Ly5@(jl(_Hy_v z0eWeopD!IB*l)F4^pYeI_bM2Fbd`+`Bl^s5ixU=`?ljPmD^R)S;TK&Oj3d-$BN6b3 zfX`>=2x{?y|8f0ob>4t1~Y=dtS-s#r7DP8hyAUtlt;o)xRjrYB0)j@7IgT zu!*HJQ$8|wy*9li$SjF5cyI^zyzm=rmS zWP&%j^^H9VwSf!M_|;o&cRJ>I^9sK9JuVmHkX^!i#CCEGY0LZs&KYza%W8;foGKlA z3Wm+->M3aGNbhWPPgr$kSlCN7BaQe_fe48FOXZHFWW^x>=mLJ2yL7of@k8L^L;tIb&!; zi$cP|p$pHZ+<+J(NU6UXOP3IHf&qA1~m!8EvtTYNTnHm@iv~t*hqUXtc+naL* z%a4&U$|X2NJ}ZRWo8we!qBcj$zmL%OkQLX1uTh<;9ja`hvfr;W)-$H$vO4Uv%mJ=C z#Eveu^HiAXfL%TpYK^WRtpj6Y9bW+37f0E3D@z&}RZ3pOB7CIu!cGgQkBg|7Zb@Vt zyJW^Z3oZzs6CXAC5M6h^l~08~nHWk|IL;Bxie4I}gJQvx7!;rUSwf}nzEr6zy&r!C zde>|Kr|By3^M~|9HMs*T0;I19^Y}I?#7SQUsnx4&G}J0c1}AYnb#b^48p7%G6Am*^L?L+P{pWNiTKN`8*j z4BI%?fV&k%9yizB?Z<18sW9>fZ7lKB{WB^Y~C25*x3((S)r8lE++XXm9bGRn#Xtx9Lnu%)nj&_nyC&~+VF zqm})GW>$^{;k-KISG+b2{rL*XuTK-}2r&(vMI)|`viS-Xi|#A+32wxNnw375dDC3K 
zUpR2Sf`wJ8x0g0vzc^NZo2jaY=t;j1#$r|73pO};SE1xUwOq4nc@lCb$=beCWd2qk z-`Ps!8Zp_LWqzS-3a_qKjlI)n>8|Z0nzi(yttgERzq&p)M_sEu-FDX~xm5Dp7-xS} zBuhu&DX+qO{bK6ULA`!^{qU6QQL1yg{}~$5;UiYOSVj8Q^X+0Qkd2FK{O}GtM!mXS z?NVw~V!gKO6{((iG`&s?jlnRj|2#5hZ_#|jYc?~N7_g~o0Sr!SSdp0DS&s2x$=kK` zOq<@wSG+boj15O4B_^gf+fu1FS>3MI+UrTL+x${^Sw^@E_%+{6C|$UK zsZGM`@A1%EW-miWs0sxLrKP51A4iPZFSaUXeIkF!BZmZA&jhzrlo$ zWeEG*8y&U_t9*@Tkx(mm`CR-Z4#VOWuY};~1gPpLx8Vd27i4NrTwi)QoAm#BxQ-I` z{k80ZGen`2J43_?K^MaeqdVGNB-Oi5c_<@qi}f;VTNu7m>$jnyu(nv+Qa;ombZpLL z0D00?pU%-&hmUg@Vq_~zb^?~f`xtiL9PPmu^;+Eo5Tsaqgvqk!9X8==ET-veL}h%n zoz>ZdqjhB`959l@*FWz0PL0a)S=9Q~ggvjC@kv@ERc*1FpIt(@Fuv6@3gtO@iSRdL zoMJ|thpD#Gqta_ad8^wsrSt&4{oZL{^Zas^$wQ4Df-fEKH-^4BK2ZXCeZ0Qyr9rt+ z<*mjH$9mUM+qpYA014#%2K9!j-R~ajNQVliT5THCMbCU)I<^KZz<8)3#x4n@r$9T& zMY(Pr&z6ZfH3Td`W8iMp+8~5IwF}N-Zg^}09fPE~4MQdC2Wv{W=3+HV^vnMu|S}7*$IF`ppw7Wnnp^gHLlerUnnJ|76EN2r&S7(9dF0%duZ_ z_{81)x_zT$ln2^psN`BISd(v-FLvZ>hhJi7BCB>E|vncXSK zdQu`D`W;^|J@h+64u{2-U5`#rf4z3UL?vl{)hs$K{lEh(Dmeg5K9)#I>47WhWNwP| zCNLsfyUvoi;(UD5c`Ov32A`OE^DI_-lzoF)GMhgtsai+j~^7^^ybD;7LsDoVT>-vdYsa&LeyuMF>h+h zZHR~%Hm`(C1LGaJy)XH0&IW@hXlm={IG51K+EJN0)ocCmYO)b`(|<^(+tt6V_&T3Q zEzJSpUg}zLDk9Ie#w~w6vNC|>bDcW|sD=A7BbxJ>b2d<8Eyg5GXM7fC++Q+Dn=zQ#Y7FZPOeSD2->G5n__8tG#y9u9fHo4jTfC!TLb1-*l8S)3KRL9^E@$M# z+{ADOA0q|Zh`dpd`Z7Jal#sa4Q79NUbo=N!p*)!f`h?+iybOd8EFJa3^O^HEjD2Nc zB4qnrW{p&i4#lPTi}n{>vn-s*d$*D<06xnQ&j6zXe{$IE2UoK#L{=#o%M~p-)DBZr z_MSQ;7WoD7~!d^tHCtU5neXh%8 zH%MpL&Rsx#q&pXycDAr zcxJ@_vEGAZ9%@*;I>4y_gDl_m`o`3X`l*@%8$JQ`TU7*^CmDb8X;@r+bB4c=V&L-r|R}Q zh=>|G&0J%}gi}AUj;DkkS=c!n6d-~6dEjx~Fq-pOYtXc2H&%?L=?|K-sNm!tgrGD;|IpXXfb=qb3A@5<26IJ#JLaX|0~}t*P<~lU)QH z(0UbIoPz4?Ryp4-B*grx)eoy!QR(}#67f%;56Ep$hB;*_Yp~a?R`#e=_k?hJ@(uT_ zpC4xPUC}Hjqp>7$+V@z5_ApR+e4k^zupJSn@~NeGKy3&}>DEo{EGhc#&b4qFEy1O@ zOV>+FC@T5RNKIOkM6y2UnjINapaKjSML<9(5KGPYj(`q~ zuj^@N(RTHe#XxL$H;@h&rap)O;{<9nLp>PVgen=Cyd*_dGlRTT6_|e+&S_XS?6eD@ zh|}Y4E5E7yQ|yLdi?;shdGARtB!;S$BO0eOiasN{@QXl%nZJTW^rLN4!_z6^za0V$ zKu366qR_mlP?s@|lh1+)i>fYuEW_$;fh|gL_y{bk^yD29>J~i&J2hK%<4me1pryaZ8%R7 zT{&+Ubm{vol&Zrgb#?ncMg%EW;18Psh^j2X4@bd?aq_Q!NegS`XLXT1u2)vWc_!2x zn1w_CzF+jYgTU+9_p9Kf_48z}h8Ko`xl^^|W!QqCCvH!I1`rD2@U)vY)Ex|F8s9}> z@c9H21hdKMUn*PAqHn!yJd8Q;4r|gWB~x|gd0=+EH3xSp2F+;KXK`GAr>PRn^A#x$ zFFY*y#94P|Qh&V%Snx>tyaH!*rH26Y*L2QoynS$As*dkKiZE1=CRK_V48lU&FClC6 z9*0|(AznOK3>wB9d@~PkW8Muw+;gU-HG&XQJ8DN0wuPtgoqb_k^HFk^Kv!Pk5!q!v zXJ;}15gzl!z9%ad_hdGO+cAjQ;KiaiOjXB{c9hWRVwBkocewZ2vzOx76fAFM{Zo z7Kn|kw`^+`tkq^rzh9~<+o)Z;Uc(I_g^{aFy*E63@2>2uoo&?VtyrXU&&tpAr`vYF z$}X)TuFd!S-{^_6Y4r`lTPLrq^H}tZ;4gdH+fMEZ*dp4k-^JB#t8Pz~>ifSfn5qpd z)#%eJuuS6Ob!m4c*A&|-`r0qL?cKx(N zM7458B^!f9~W8>W8DtR;Y*C}al?KvFxVid`_wU4 zXgyrLGV{Fs79#BV8skL2kEx*rSpeVy_#}Qx?yse)?R7Ymv#Z=Ls!O;}MA)F{Vp;PV z)0T;_hUG|;(N3qI;(eAl`iF{f<+Zf_kD8)i2;{SihSNn8{hp%t&e3^o&Qk-&@6r0l zH1rc?QiJ6JYfCG?9ymVAb2e6E41iVW)xdHsATOR*F%E#~H*uCLOGDGv8GgU8>|(+8={O!O<#$G492V;*3T9dHt3wq<%w z(UiH&wyf*f;#};A-@vz#9AO1;tpj$9>I|RiI+@V|Ziu$F9V4QDF ztHICy=z+T|%Xd}20S$#E9$jag(BlUuLQ8{$%cnQjb-p~jfR%M#I>cq646Dz!ES!Dt zy*eE4KG`Y(Tc?&@pJ@4G%qZc~;G-O((^qc1sFav=s2nCt`Vbn-sd2+g{Py>j0&G=k zBsUF`B|yR*{j3?J0IkAf3Z6Zkn$df_Zli!DrEPf-v03Z5VE{dz|7MV2qXM+k@*KQo zy>pKxO}xj-hiVqd)LZV=aB|EYdY$Ez%uz)jL_>NvJeIVHPOnZN4x~m#92K| zbYmR(R@I<VTUNWw`d7fndXP*-A5Ys9NR&8@hw~Ty&Z_(jMaESymwaNI zF48q^qgKuUxgbp68n9$iU^X0QtC9M)c;lw~R&uV_b5RS4Rdw8-67RDwY_A(ccZqT4 zyUA}wVZSsbTuEzY{0>95Mh~vUE8Rg}_HVNixjp{zY_-Cehu?`TtDDIQyQ^0GC-lhH zSaC#x%~;%AJDsl>Riq9&h^*p75=>TPm`Q!x9vq;g-zBjB))p?bI#F73U%C8-u#{TF z>{b9wsI4ZqKpiODQwY_#we7Lfb4cGbg@e^Cjp$?|<3DL-xx0r4_c~@n0F$gFkyC{= 
zg2i<2#T(WLQ(8qbSsQO9mZr*PeQByPzoTxE)XPT`<7Nhu$}Ea3!DKs}|-19K2kLA|`#5xBc zwDgwg7QvIs8beLc0J{Lb<~-g=3u458KRB4mJ`_wI9-GFMCd*sRq%!?{7n3D8D@MF8 z|IymPOu7tl1E2U<9)d~@Oc#2Oc4on?%|^s|IYyIZoGc~qB)X7gd_Vmxj-C7`AKKWb zPu%_+?gel==fWbDlzzfTq)lx-kYR_BiWTm2I`AL zpgyE#@}GHVPrjWn&Nh-M89Vf-5Z;$@Puf|RdTpN{BC6!Uim5zwtlAZkEOlh*0SKt5 z)`^X0E9i0yIXL;V9%?gJ!@MN6eaxw<&Yy>@oMWlm%EI0Y!?4?Ql5NO0 zqMVdqPmbKkLh%ozp45nyKVDcqM0{Sb`1}JjXt>_@XEdB8BUB(stt% z%-?_g)7mPVFv$W=U{0wlrXQg*txB=kv+!zul-q1kxy|&ZtAXNCqN`@)Uqr5`QG*CY z-i21a#Y7?WNqbm)dDq8g+eD%{a856OUllW)koYSfEps>dfH`aMi7;hjrT|Lr@LO$$ zv1-!c5W?_f=D^ie+0_HK5Z)0^whuk?<(zkWE!v8*-V0+YV9=xvK+t9Eq}nC3ItbYa zOjj~1)zYFaheruXa}T9jDA`bR(-ko*^|D2n*fNBB@e53jc*2 zVp7djCexVXl{RNkrLhm$svSX-BPOd6ny_cIR^smDcBnDM%@@h(5v~!5d)!Fo%o^{< zzfJM&t}V}f))tL)v*UtLUE-d3b06n0e&Hl4v+CIH#2AtOe8nQwPD8OS4{4ciHFxM{ z-rAG553pWM{15itGAhcp`~SW~3_wJO(m_CF$RU(eTDrSSK#&fRp+lquq`ONxhVBqV zQo36Z1cs89evX0by6g8}zjfWudfq&1{olA2zL=TwJdbl9``CMbKVRpF8RE2+x==ac z%%GBcWHj`@&It--H^rS~ga>!%g9d^M5BFq1XA+=K%8ixEIGqh3BCFbTLH9;W0JQc2 z9vs7CIN{yQu-k4r`L#{VT&-C{RQ~X(avELYWJ#;72t1kd17%;-XLwwliTIP$5=vhA zC1lbyZ%jDIL2Q?A1&s!G_HotEbQq#OMJ*8(sNO)fz`UhaDYe3X3sa?b?N3_0q%l#W zvJywp2k$M%DYxav2r1HVA%0Ue2F%y!50)LLTe16n4AapL<1SAGhiYukC)ijjW!kbl zdrDQKE1IgArvuJaUg!M;C91P2_b}7i9)%!B;1EVJsrB#W>}URdpgP{MR5$Tyw*YA+6PvX1b%t#(FxL*Q zF6c=EBM5bnnVn`H&SA&&)0*c|#CYJC;is1`*nItL@H?$`rIT-9)D(GS&Z~#}Uj4vKYgSnOn$bl8G~hQ*p)^F^47sw1m1&J;ve=$0 z{iE-yJT{^9gZp9zDc3!IKVpi`-9CbNvrDL2l1hU{`zt})z1&a_sQ74zGLY76KWzC0ZiVSLdBG^C{;(X`4>RgY@eNDL z5dX}ePh-D78A2Lc?v8<=sT8FZYu2K^R@Z4opzwi~?Nb-fgAbb<;1K?HFDJ@pwGBc& z=>M?iP>|Ck6=-3z_27A)o&xS>jg>CCnopgw6w669oI|BZ?TZ-!?I<)7ug;Y7s1aQ~ z6(GG+qv{TZu%27VIGs_84n=D+uH*`)0JDvuH?{!nfY~p`Y#!l7g@O%=ZCFm%c;BW% zbkcs5Ysh(}y@~Ba9!OEX*Fdughv1tn6eleK9mRGhe5cfR7D~&A#{enK{i2?EK!9k; z(Fs&*eb}dkxjRkSUsLC^d`uHJ<4Cz26)oYX0X3_+P5G|R_}qaM6`@IJ!rB@d9_(|LQ%_MDH&+3F2wi#1a zRSrCiQJ*4>VAI55bK>(yKA9#5qTGue0^G0KDmcs8?g??T-4pCp_iKF#Mo{0!ryp3m za;=P$%=Smh&9*~XA@x-R-@Y}oi}LFTf+Nn_ld6Z<%$P4`Z6PXd&p<$43_>xVuQOtwgW3;tZy*3{1Tubp>%B9y#<9KaS88Rm4_V zC>F#mgF4da$=*^WILPZOgEaXIaL~deDWeGT+ew0^nsW7oar>WU+6XQxYXWq`(7lXS zWj3hroPQ4mo6qXlsv0pf+rUek4$9OQRO~jiiBNN7_0PB`7@RdnjNtsF@l+|j4+w6e z%m@^&Mrxr<2@D>nM>46&&zQ26k3As`41!zANNqhc{*2XIVBLt5Mugij;^N&Dv7aN0L|B5QAY+zV%;=A5e%`=oiS!f z4V1}XpmYb1lev*A)1o@0bZqrk+gIa~v{7P%&gq=9dqGr+zRrS)b*N_wqG3|0nA@{F zU$DOa}>8CaZ9H?oxp0_dWi&&*8%gS8~_N+G4ux5#3_gSp@;s8_r@Ln%u-eEUYanZV-ep_<9~QD$`XL&q!fnu2v?of$}@Dgh;w zhJ=$E19f8|xwa@ob`=RhN`CEBCGqeYtGTgSpYhZX<C1psq@_`9@#6 zb9bGQ3^Ff9kZp8T1aRfEvzIl(3qR$pB&~qPn(S^4JQF8!G^h*iL!; zC|@L5vgQ4Yl&6@xQv)o8V_ZQZ-d&Q(QkK6fzxiaEN@*_uv0hxzyQFJQe5#_g+~k>C z_Xce>)pa_KRA~2c^EOQoGOM)$NldVNZD4)*U@dFQdloDd2w%-5QVG`;oe$H;AogY& zN_kLTF8*;q$-)u`-q>o%@jkM%ecf)$m)+&WC0HzS5ypmTDyUn&vA#a{b|hEY@XxL9 zPId_TR8?fdQ!%oFfkR1=*8#Wp>5&<-vm<@@5sfFcwZTOcg@i- zU!_*4;6QE$x=8v$2j^8xcoi;W^?aVV@#efNLZxeU;W~-=9aR=U8S%gKVP4UXV_`1i zDOrL)f%f3nZ5{ev&uvDEyABOwDs$by;lUQ$W*>3py z{kjr}IIaf@qGsMNyp}k-hqO$MxPB?i$c6w-hI)tzi5;+gXv6<^ItJSI0YF7{gwuJ4 z`|Voq8({|1XY8-}Q707vUy5R-u@1*7?*i z({QIxEw5?~KZiSmM8F(~#7s)_$KKp56T)Ds-wo@i9aWU&F~sB9l^bgTo(?=8I5pg1Jw%-1K#=|@UbJDktS%6N{avoAXYDWmE~_hw#@rM8DCYK~MSn9Snpb=IknUEn9|BXX>;SbQ79`zDP`JD3?Dt-DX4F$*M6we$X_ z^FP}DVgbO{`ULS9%Hj+`%E1Cr1IkX%8=tDMwk!64f|ln@%Qy>VtD=VIecDk8WT{gy z)9QsV$$Zc^Jj^=;r(^@H2O!h31Ct9s9r8_%}Jn{`l1n)B`aNvPIy1yi9T+%r7&~mqrmqJ-z?EkqO9Mt@rw~+3q z@G4fOk#Y97Kc#99W}GehQeKRt3i`$rd{=tp2eK!|F zDMTlrX8C%c7Fvidu8h<>A23Ti#KL?C%pQPj`t9;N9pK0KK6|a|M-U4Zr-o-odMB93 zCEgGRek!<+WxW~DlOy@=?pkA&KCiyE)Y+30UA}hVX5j zr5q5|bF7}F+|dI=S|PN%cU90N47wkFx*5zmpDfA1A9221dG;#>7aRiaugluEz(&hvA0vz@yhgIo6l--N?4q#wQibAcc 
z94&MP0t5~bI(Wgu^2 z0E0E`_8VFx9F_^v3D5eeq`}lVm))&L9@{L9KzPS)HBn8`gT8iu~?&yW7w%l?m#0Z+FaJiZL(dGwkVM*Rf3{S3z{MfgR{gH;M` zq#=I^n~?i^=$ntv{@ca*&uO(k*Bbgg20w=)=)Zl%{*pla^Aj)3TRJ~b{NLO3FMde> z_cr~lfc^iw+r)4yM|KfJ~~L=$)^WiBl|;+W=rfRxxOOfd@qv|juZLbwkFk^M|&iGnMYGVJw{!_ zl0QkG195jI9v>IpMa*JDZxA@#m)|&b_usss{;Apc>vmuGvnYwwf$|4undXpAqyFiX z&r93hXuAVR{I`ueMvOD=l*b)-wO(Jpw0*0XcH5D63u8YeX_AQNJz3+v2wX!k=q&u# zC2Hi2=)co!{$pnYA^HN>d;bl3bJMt|K%u67F9ObSm z|0#0e2%2KTalXjwlU$}mn(D`FD&o3Rs{8K&sBM|FQeqQ)<(=DH^<#H^DV$EW|Nbhg&sp3QXk)lzPOi)p@K^% z!P4rZPLV*;krvoyc~=xUE*{hZ+MGg=`r`et4uJ;f8ulb0zcBtUSjawuo4{WY3g2sZ ztVX;3i;cah)3tHaBEE2RK+xgLv$|rP%4`gkS5SfxK`V35X8S|&{+stS3&fN5MJ%`1 zP{O2JA73R}FUDtd6&%zIBeZO1g=*mx0*A%;#kYhpOFI^$XJ&>0Rj5>P zyVae+(-X${0u$mheOJvq<@;}+@>QrTVOrz3Rjex?7stgFM`jo&97eE2H0Kjex!z7 zkZ+$oyQzZlzZtu!@PZV9!c6R<|KO+;d~y=^cY)*f*FBo1ZL_}(*;#v0d2YRhyt&To z4?+eO@bl@lRSYY6S-NLVny~~DXbd;z4c(vxf@GTewY#PH#v-Y1Nu(2PzBd#O59*l* zeP*~`@YS!+aaU-$+8%sMIP81zg)BKNk^G9q?$G9WKs} zc%oQM5)c2yLGpjz6FRghf=LTWsqSJSvC)_?o|M*=cGs-xq8F>GzTX_3cQHt~*Km?E zXO9ZjlC$66e0{>}dCW0}C7@wgmt8l93U-o{ zj|)irZs1PM9qRhTGL_VBkLFzS+>$RR3wL?Awby-jSb%FccXPyBi1^H$=Xhmg_z+zA zjo#%%T%+A1qnDG&!gFSWok1c-mw~rS2gT(YNYqj=V)don3TpCB{?NhQluJj+c42S{ zK2jHX_S_$j{%#;8%KzwKnhWs%_A{awv0o}ItlXhP-Mu|0PC4t=J)&p!L&xnW!d+{k z?kDXH2pxs_{V@FevX|)^eQ#2C+_q(xt_f$#@qKZdriibb4LOk9mMPIF-Bf2;`ndqY z>R8S(Vrx+<;V^LhY)Q7n@N9H=no;rK>w$C`*UO|(ZDKc}`gWmf@aYM*O6uNK{IGEo zcM>z%y5q!y*dqPkTO<>9@%eK`>Bqz)qiqCD60UMtG8apl#s_Svw=ym^=PylwuqO7u zvh*duxmE$Xl}V#E-s(tGJ?=;sN=x5El+|b@*-%gtzQ1PBwmCe`bZ*=Hm}}X;_jL1j zqYajdTVH+Tsmt;0kvvJs)ge1J5V6Wfql{g{6LbwLKk5qWsLxcTS=sHq;cziIKeZL! zq^`z|{1uph<->cw=8S*ctoB#8y6^TA;vCd1&7nO)d6YV4U%oDDaN8@|BEP6c>#9M#u zYd$)|eGdFe*o0zg$>7$HIp(u6hNnmtukUYJ4MEKz=cP9D-P5MGW(ftSy?AO}ejQcxzN>KC|FIRY76NbF z@v@oSbe_yyE#FD6%TQ(?E8sGjHI-$OZ<3h$()r!2v8b|*KD*c= zu7gr-yvh5py2<+(Yk!FAt+``N$5Yp~H|_THHU{uxeN4FFUu-ftDs(`3c;OU+hptWj zy@2vwS6K68c|2b2Kql`vu`Z3O zrWn7z9i|kjD{S{&Wqs}am(iWZtZ!!{Zrr=C&biK*0>Oxd@=RSFlI>-IM#7HoWMj#B z9Z#(Yq3HTlzneS1YNL7D&PjUk!L`F*_+h9539xfMIpwrqocCZ{{e{*H0PT-!_cX$q1cYE-J>XP<)geMJuJjcvMVW{bG|JjTf4osmJ(U z#;)v~jCudQ|FiSa%18CpAVh@n%^woFnH!XI!gU+Uvwzh}{6*2bF!-wRz}jofZsT3lIUOcGbUW z-oz7io1e`aj=SugUInp-{0izq9`2!qPKnzueS&WT&J+LtdM;K%gE9v`a1em!tN9E` zEvn?q(@)mk)ftH=TVWOLO0A)ur_2Igg6sR~(xbn)4g+zL-xlpan4KoV)?`Gl!AqBz z$|;b?X{|@=+G+`_!IOQ={l071OC5{ss?kziWSnnbt&V4SC!c@z;|Z$Yl8mqY#WmMz za*QZWx7Th%jy48t#2ee!Ymqi&gEe-v9zXeGX}{xV;|P|eh!7@6{PP!GAMiQK&^ybZ zl%0N*c49W=Sku+X@k2i3K1S&jmXpi_-d3HXh1H@7Sk3)60=@Y+qN{Wl5BOH}MNg}o zj-xp%sH;!TPj_oi=!~@3I&Bf%Fs!4($_Em&{gLHE_O~+%t0*rr%{|js>m%7jHpZusGUnTr zAIF4tO0so@cOV6Izvj3}76oQ^Q4`JklO+WO6ZxDEdh?OdN7h^c4eEuTX7NtCL;8=m~F^q`Httef4~ zaFX+dblz|tPm1v=#?qoe9!s7ek%`Xanb*#f!3^!U-A;InvW;A$4YoQa?Ancy7SpwX z<;G{)XH~ix-aAq|n|$Wx)0|36NE@dVquFG$&F8YrX3vmxrwIo{Z*;!q*e}fe%!(c& zz&w;#odb5QPVZzQFYh86RDECSa?42nyITgXc+tR&_jvWs?2U@AA_Y4eUXQO$FV)3q z{X+UtH3<);y^pNMiF(ZwDKxlTDMV*9#om*2Y^)QgJo{E7Y6omBH!LO_TDPB+T6frJ zS`2TFo8Ir2v;bS)^SLgDo&bNLN8(gSx)&j7dy`3fdbmtXn)~J!UxHTj%s7Kt#P?Uc zge$}=qui9WiqgEd=WPx{Nzh@D+VG&OEZJy^w zm**PHzZ;AmVW~7nZSHxNuXEJem*8?CllFYl5D@t~K06QX9|G50mxb<&6A@nIK7p|{ zEg*1?<{J_DY4avmwfEGU)wwQYR~D>?PjJ2o zTsx_yHu&^JYV5>|984 z5<9JFM(b3Mg{=&Ak>AV~J(YHgg5h;li<$WByDK2=*C;+Roh=PNK$$z&%1P;wm%DVX zS*_I<=cJ=r z-Bizz6pBl3Br1yut=HhAyqsUN*4MnKj{qsZn8CU^TEBZqvup9~Bl`$sBu*vE_I3l6 zEPYA(nN?ZfiN4#O6vJ61$PC%lj1c7e#&wfz3eu0<*f$26skqn&{lAARoy9T4kNfmY zb}o$15`BG^^RY$;xFQu|DPGnXhN93Rq53i;ToL-a83nx>t%Yo^;u0i3LSf`KNE5Ho zg(tR!RjpS#YS!+y@0A8ravf8|{}8~_Lr8+$czQd-_9A!8=0V>yyUUv<2%NY7)=dKr z;M-x9VO;K`T}02V=ma;(tdMTJvqY|n>!F4RZf4_sUuwoR3i7^Oso&lj-H(rN$XEa4 
zEkXc%bE0mB+-25!75&X0F2DJe-+%GVoQB9AyS1dm5}arTWC<2aL-B1!dihnNJcDTW zdgEA9>lGt*$|f91_PczYBH$~>Kh@&dEzpfrcHPf^kUX>2{V`>-kNt1WJ2eyV@wDrpjY1np zoU@AVR2}`*sggt2D{rM=KIzo|*jWQFsotQ|tGsc<65Izb3$}j2FD$X`3#FYxbs~0bB8vNRvQH6ysuX z{8g3x0|W5_EE5x@yZI5V0YiI&;Ez-5ES1Z51KlU`x$G0ofeZZCLZC;h@OPlWuLSKy zqNyB4?V7QWM}pXy6jv^m_&=_}qvj#{Xp{>s0kolx&}}HUzuHh2p3Gk^o9EV`dH%0S zE&q>4Mb>iyB%HR!{am&H>!$%8gujfez<>PuMKENQ!5a4e;^Uv_HDi*no2$*=MuTnH zEhbpeXb8JgG!Phco_l(B08{#(x9QG1{tWbk4g}`t|K*}`xak5E3U)xmJq{%Q=4dYx z&>B2WMl{m#=eJn5!$-ak^g;jB@W++UN&mQw7d2}*F8U221#30~ zG;2@Mvk82EHEXfaU1GgID@~wF{I5l}l&w4W*i0i_!6@QLmMpoZQ6ggPKltApRX@xZ z+C%qu2W5{FXFwO-98ed{x&H_MI|Kay%RtZdzg>)E7%$xpGwK1Z4|IIF%|;6oXN%QK zU#t7xCS3y}HzX-yNS@tt>KS0VFP+rz983+`0YY%=pYU_JV9eay+_gMpL&->i=JBy5 zCA%Tey;cA#m!^stCo^Y&qob3jM1cEM4twHyWUE={P>fg|>BjLlQqeT~3sn1qC%W%|V&YY%QJHz=wB3%on~yuL z|K}GsNDYv6+x`VWiLwtQ8m!qCuU>(UM;xY(%a-slV1O@Q^TZ)$+=Wk8SeSv?nR;!2 zcc}oX<$YRopcuJ^j1m97J*Cqm8) z5elEIbV&3_`F{PPBIkPpgo`>cZ9a6Y|3^2x8GZh`ciED-4JIV`=WAD-Sz0HTs|PJc z=C4(J3;ft+)SX~iLue#YrziD0v_(b3)FR0s-@Z1jk&u?ed zESW@=c4iPiYL}9n>#Swqgtx`?Jc@ioAKD5?@;|p*8}q63Fu#lW@HQ>Wdg~boV*6Y1$!omMifEkW2Ve0@z)$ze zBCnUEnia*`qpPpX*satB3QxorJp?a6;$U_qluz)gwx&!|CvYaH|Ko*#S1<4Jx6L;J zL=1IVVP&=F04sOgzqpsrMD<~nuT!0Xq*-uYNG+Wlv(W-7s z4&3=fC(9}VWqi6ioXObrZ|Ovxo4EbjnqqZ z55-C8gaRat!sfKZyaCp5FxSSjV=%6DQk$H)9iwE4bAXKrU!;te) z`m)hYCn9ZCX;U9+@aAYGYw{`{_jkDbaW5J~W=lx5P}oxGAW%;Ko?jfiED>H^;Ge&- z<$fc(cWPJgl`fUQPv!PnwPTDlk1JT61%MOOkbs_-wgW8648~GkY#LwPMSyyTfhvFpTLHkqjUNMjyD{`NLmh~Gfct)dl90MmH%!{%qK(F7p*F9g(r*)9OdKUSxy$<7fH zNAtGpN29<5*kr92-$IC>R3oAXt*SnIFj!sCW9VlAGj6^pa7Mjru?prpt6Fb7KclkW zfe53>O2Mwlq}%UW3U<)5vvke{EPk@g-3QH3B9A_-K-~;IQsRp181fj{?zvV;b@khKHmZgutWxeYn6wId+Vrjd0&?Kk30D; z_CUsQINEr&NKCVT49L*y;3e3@P?%vw?pz}`OP59R;D|40Tzza*4wS{om{dNL< z@%IqNW_fyRbujbqB@SK5mGJ@$@iyP}t@|wJLrt}S^?z(bP$8~^Bk>d=k*zfzF3W3D z3)o87Ubq( zWENzoI=cxGb^yy+1TKFPz&>?Y!v_d!rAFnI4DM|Ze#KUac-b6RRh2n*WEL*-*v*mu zqout?r+8<$o%=j@II6u8O(e&jw%>_ts%aJHPqHB#khQpntr~BTm1q}7o|4|gfV~-l+v|(IsCsvt&x^us(5s zRD6ub$+^KfJG|0YpjkVw0rOG!Q`_A!4MY^7BdgxfmXK1K4n!M z7H%xwd*53?sfyF_`#7K;NX@hmK$((Pw10y9v;sXjnaok3QNI0N)6pNDEpKD1EzB+X z!wY|5=s89Y9#$OwjGs(-;ptLE;K;6shLb}9g=GGevZ`Cn>U1f~- z7RKHS{?XLBA04oxaNhGb`ohmh8#+mT7*evDtL|RdXQ|}1ZFC3;i-9t=x z4b#aU(tU6A7w{x7I{K+-RMRq4W8YGFI$?&~BfD8%U_bT7*W^2Y7;R$`+XNc4IVAR& z7h=`58r_^MHr7Aw_Cddu zx^q}}JV!6=PUL}>|ty0jgg(cqq7nBu)+Fu~h}4AT{Ej#Q9N z^<}L3j&Z)8XEhPXKG9y8m>npnUjs&Mj~);Egu}4=-)bzCULRY2jzm+*xA~7jk~PEH zPHkGVdAecWxD&8D)kCPEut|(Iv2ho5*dI4CD zk||^B7MFRC0BhgzQMCi=tO$tO$G|4#uR44WRL!0~y!QHTKEWR?sO#ussOPg==9rg5 z1L)4VM_9$qGmA58bF@QB)nWSi1oP-Sx!Yzxg~-~>D_nEaSeqRIuBifK+~<1v%>*(c zf!k=+yt#L)5E~=7=R9`6;3Eb~KJ1XmeROX`jLaKTjM>%eh83Rar?si-Qmghp-*U_w zikVWyY@=GGA}VZaIjM>in2|5Zr%gW|uUomttI2}W=)Db7C0B5UtB)oK>`|up0EOLt z_Wd0cu^3Mqtt*_KI3x#Gq}UAN+K5rJxD%26NKrpcLF$oPYcTnyXbrqYE-dGfQkw)- zFdSu&8@N)X<+?ZAATQR21V@wx3KW>_-fI&h6?%h~7rQ2crB=zuzO@&pnw8bVfMCpd z1n(GN56iGn^0!o^1wOmlAnv<6G=1Rg3X>xJ3}NW7E$}lcrb-?7BpxocB*$o16V=S- z2K<7INazC!JZ;gZSd1%v%C@} z*=pnsGMjha4Aw@?kPly}0EymjnQ#YacOJ_spbBg1Mb=$E0pGs+9v@Ldsov;zx`v+L z_EVeB4AOPva|XoeA9*KroYf23Qkw=&$#p;D%v8j(f-Dg&tH5h?C^R#4^=4t%a#4N* z)^bhg6T#W5<+x9s=DdlWJVXcBvR7rV5)mlCz)d}Hnx0+iO-}wGcyPn@s)#UA99~l8 z?kSar_|#A_gM5Wl8Y{y5h9Y$_=EEa)X=X$meAnWaP;3O%XvxHyEREzXL?fu5LS$V9 zCBNq!zh?C4Z}!GHZ4l>XD^{k)dy1#`%$88LAllJ zOmZ6}pdDz>AyF%u$}{|L$B}M(aB|M)`Kr#5LJyo;SG9%HBC1wSU#ts50+vq&bu2th zVmFu|o(m-tPRYZ6;7^-T4oEXQI);A`eLo2stXJA3g!H#aJ0g;C^wKHCBuuT)J8T^d z#|l|G^J0_oBR@eB!Mdc(XTjinojwCvXcgv;+s=+Al%Y@zgkZ`z-E8Rraz&NHTaoUv z!t#(3y(PW7EL{HhpqmlbB2C^v5E5bNw000Ks!E_)C6ZtwB*mymJp`!Hny(OdtcKL? 
z#fDa14?L0kqCS+S9|k;SIe~WGZF7=(%|&4T>KGanv2KWqH9ZFuX7CDX%cG%*-a2{Zs`b z#YF%!&|u)Ecr4VO@)qqdtg7FbG*C%`8OswSs4Ij$A83&^0*OYrl3OYvEP18M zb6Z2jxd+4ygU;ccCc}R8ibi>LAH`474kFQhL-2(?tSP?0qmgB+SkAh=nKH0x9UWXj z6p6RfDo$_DKsfwN76=70Ee30oIKNWH3^@NdJ(x}I4uh%1vYAC8Wl6=Lsf;hWz%agR zR+5MzJR)shdNURjfwU{=-S|^d@AN$TMBN_%C5HNf;#b24?-F zw+rcjeSmbN@gjq1@7sgOj$JV61-59m>x%ra5eRP&$;wJ`))W z!{Rz0nt6l;imjOx9MRrPE6?+teXbj~+?W=Fm4dUuQSM%YH2@y^cK;G!^Wy=yaklg( z7U*)im7XI9?x6FDZeWE}m@L6y zdmr9kAY{HBk>(uj<`u_An;Kiu*{bu3j3A&J?F{NxS5NHDb{-j}!2&@^U#C zI-ove-&Jl@R4y?DcX*q}f8U<(a%XLMN5j!Y~8SIC5xWx(Ls@ zFO+A<;jY%KDZxjPD}*IrEN19;+I48G$lctlR2>WVDL#fMwkLhZrJ%L2jTM3CY$-%B z71l$2+2Uj~eI9MmeqERrN`t(KAifv)TR!~*G+?+rT?E1=v+_H4j%MOcqDoO%MJT;R zx&kCi0Ta#{%vK^9*ik+2pLZK^(uwnN>W|d~rmP%2xj;Y8gW(WDD@r(ljgn64ZsiBkmU~TA3Xj%Fi`>T3a`|6tt5;hQmvlWN28`?W}fWT ztN77-=qcm&_y#O0Ld@}Or_>qn3Z5Q=}Jlx2_)uM(Um zv?!g(+MT&4rK1dcAHH%%G$Gl1)c2ZV(+rYbZZg=442u!kOoSAPW0DXS(s|5?vWVf|n$Qa&IL++Ut;Y)qCmeb(oHezL?j zbD9msV>OiqfKMVQ*3i4u>(9SV;}Q{uFMI0XTIyj+iKXYXupBLqWRrUP__a}~LfK@p z!v*)UryeeU7)foFcbxw)3b5(-EYh!sJRFj*uBKm$<~RX5`yD6Z3INk3U9weu+k!>C;|~55lKSQiE!vJefT?O+ZIzDQskqL@9G*U z{ay^kk_vdY<9l*eyce!7%}&?WxiwEE|FMNC4m6VGUs$==+R=#4Yt`GMc$D8ygW4}e zzCN~xKYrixDR6@;7OBTB&JCJgwpRtLPTgKYlN(u^2QEa#v z{S_X|WUC})?2qmv$ml+XB)M-NQEbA%q(Bl_n2;!xIy}6cej`$~i&CPWhyIK)&F<#3 z>TV!?uf?3}3r7P97*(Q?sD9It>8R~^5RWi}}Oy|%R_KC*6+xW8KK-XPP_j3J_ z_yOtkTiDUyaUg(GrVy2Xw-C81zUvmzA}BynpWHET0hdR%RL?j^$WBt}Knc?Ax2i(q zOyc4|-*5dxBj0O})3pgEi7p}2n@S;{FGtDbOD$^N4=es@|ND>$y>W0~ZXD@e=!C5k z;7L8-L;u*ji}bl3sO1;WSvd42jeMsl z@)*_2ZacomVq84!{Kgk3G{oCJG-o$ql4UrukQ5Ng?wP8h{H6$iR zLF_mE7#2+&NE5wG7d`B)E1YN!y+-X3Cz6lizNHgrV{xMkT@7!+%J!TG$AFzQd_U+A zmRjdcZ9gbsl5p<}o?rvi2b{pIt)m#^19a>MHfKBOr9U*O?uspA`thBoo9tOq+n*ga zUlkb?br%(-SaqZd#%UtZmskR4A4k3pL(Z3!WI@hP6>n6FaI7ASZbnqs{qnZ2*3!|n z8Sb*V84>@kE>Ua<;Az1jZc7JC$(J*HFX=vBtbF!ab~H@;>Rdc+64MkMcGkC(LJ;6T z8@53Si39U~+-1v?R59FFmusaD?nU06W?!$xL_pOZr3P`=wFp{y*>v)BSMp6XK0K1o zvKz23sG#<`HC!*;+os!ZE#!S&qutEAVQ}{}Ym@&LhSP0{?37i3(GMH=%-WhY_Aso@ zZTSFy&YkkB?UT9JMcQFCbbHQd8*JCjFlR!$Msc!N)TraY?%kjn(Qg)xo8PA}E;HO}*qjGo>hCi`OgEB;jhIT@$}Qm<}_% zQMj|8{YEm~lT}}1!%MJD$%TZ_Ouf*7%YGxlf?FXDEISFduthT^|IMtQs@H+V$|rHH zPeI&)Qv=H*ec91vvhrv<%a%*Qs&J%waFRz zaBBaU(@Z=3t$hIM`nE1grKUZjv?LW$iTX|H5{z?214KKFDgDFPBxz5motPBH1S644 z>dUqPI4bW$U?vNQF%ju6&br{tU-STan!mU8QGWVvTI4R+G=l=9WM?_3K`p7Tjf^9pvz$7YCdT z!LeIdPXyE+dVd9FVXY*Q1rvUZPHXb8XeyrM-7|PErAYz%ak}$IA^szs0bbwZS1y_$ zyMu9J>&fnt5^$%r?wGlO;{QRa)UF$H^T{nx->%sw_$ea*am7PD&sr@dkcC5}JMxm~ zlR1LHtgKUzq*{hWs4BrU1?w~6Om;F0t1w+k)O|GznYp8zb4Lp=gys{f#6)^c9d-Ku zfX8B5IQeu<7S7zaDL47uB78E)uSCRALF-+_iOEnPG;UEwg#stp z{mhYkM)b`@KslH~6o}CWxUE+r7zD9UGbO93v_Rq1N17Cvo<&qiyt6J(6O|?{DiiC0 z6~&CUgHoANk4VV`E8XG*astRxaIs(m{8BMXc=8-obh=zxVgZJZX>3t$#Cj>yB8 z+-;1&iNBZD`z(C&08AD{+?oSD+FhNhq0rpWU1qk5C48(1Ev)+zRChR2R~3uWLR(QG zDn>CUnS?jzE|Byja;Y0o zxTc(8Ayo`9z%b($>ywAk#;{L;vV&{Eb&2hHG2GNX$0~I$-C+2O##a1h0PK2o9dFLZ z?~!R1lNCC0S%&v8=P_-0SMS5G_!de!_6(QZ3-Ip%|4}CcK!}7`xv(N^&-h)b5f+_Z zC+%eH!*r`Xt)&t#O-P~8*Brgr>wL5eZ^Bj@(nI*SwLZ!5G$|pyEFg+Eu86Ctep5Q& zbX34LMDbE@sz6;}ikfjka7f$o8I*paHIvzGFYl$dGFv7n=+3$R{1h&;t^WmjpMpic z|IL6aM_5_MiBOTkv0z|U-#*D$bupkajLsg)~U$QXjDpQ&RvP-*eY=${ARd!pIkXRMXr3YT?NjTh7}aetH(&ctWYVzh;9F$w|NO$uCIpkST22 zCu0)_k=9y|yD%{Wng=s58YQDVZqr2&2gQKTYM0*(kD=32#0{_SxQbnAq*>x*DiI=c zI^@o4PtWVqR<$sk9r+bRiz?u^KFCeL6rs~sm3->hI>IxgX1O6D1fApzEZt|AcyI`& zn0a&;h=$E>@MM7kk>{;lVatlW!vga$4!l3($*y44F03f;4=KtLo<0~%@%?g}>)X;X z@T7K)A?3qAjs~IX3KIn01^>{zyF6(N>fO8VhvcO0ycE9C?S+V=gN1A`JfaV&d*OXY z7hA93;Dbx1bJ`5gqKVKk(m5<2ICIWO%HPY5M1CxiX(pdW!ZPs{vGlLsesE9a*o1Mj zKpqbc-QAm~s?UpQg+494dr&}(8j$};|HR;D^G1vAt 
zLrX0dvJcrUe{z6opK=XG2d>BCo2;6P&J{0?e^w|X*(=UfEDBx*j|eY3zLa|Kr(C&X zZAqDzfSNt84o=HOHTqd{7Msquh${o%hJE*NYT2BD*$7(2)7hb%!Q5T6UBfwt>(NKL z)9#q3gXh3}Gf%?3^mTX7S$7gomV`U-jqg6|PUJ$;opi?^Ae47ESCLDq-i?Pl#7bY! ziRjhT2v0SutmTylX48ZN@~}@Cbm~6Qb#}~KAW%|AkR?Ay& zDPLO(DENX<`pubO{7E-f;=4%x?_CCsG8#J9V!x^ORJtYhicU7$D(Ch!bM~cynb9Sv zSxJ}w;T%*Vx4A(Z%!)iOt@3&BvLUEC;2?oTdjOjwt$zuT7%85I!}vBZo41kVG=Zgn zL~nDjsz{D)pz8Y1HM?4o66k6gOV98d_QV|xd~M}L5hdm7R))qIHacbWBm&!rS*xGK ziPKVry*EM}C*cEJ^QD4gKR3GX>Z&3AHg*-W&_@47GUIG{7~I{qnnW*Znw{FsuW%l% z|MUpExvb59aCbEH<)Vm@3h3#XQ7#5kjqGR`RtJgk4O^W)e=5a#zY#8_3wo5>y|<{< z)Kgzm`*-uoYEkcCuaqV9KFjCZiRp|(qWhfkN&4N6 z)Ez(HrX3}WH`g~~shEE1ul%r6v)E<*{v5j6EcNsN;u`E@rMN_x+UeJKQn&N|jz6sv zirY=)dGEFH?S!_bw4Kap@}*c-O=JZ{boFh!WSQA_@^49p$4@}Fz}$8iyHR_LB6$Q6 zg)9oXChhqSCnvm~?^(Y}PoiC_N6Ato4ih`@u_k-G5$|$6_gkYsN;^ID)`iPyhe-IxCMajDhf9{d^&|i!tLW;=&xMXJ$=J6#-cdO85ErJLOSwS zz}7Er*b*|hcFW1cyY%spMj>8C(kTNAD+>;^7EWw#16n`W%iSkndhI}Xz;WN{_2VD| z;Rof3&vhy1bIS^uyXypli}bO;uElK(GpFvkDn6Ed!a7+q-27yO=O0ogDGxp ztI<*ZPE`5U?xOv@^XdgkbvDOiD6p?aG zL@vBsUI|%YSJu0jaLf@ZR#5j3HqEh@@4cGhust zk)cMwqV8PJWz=95hnTXyW25aSQg2xQ1{ej-tduhXhCg)V$(-D?A1^QO< zmnF4Y{N3+&#u+0Hngq>hMRFLMT$bz9Aj8jMe8JEXp!}XGJq8N|*!% zgC|4EjMZa&_;aH5S`3R8!s_gZ6a^IFyM2=uca)|4$@q~goXVMIa@^BB1 zM3P9ZtY2fh%+&Ect}6EXLW~NZ#Cb-n^wDCwVyGoE*vz(T^G0B(JJWELG{q%+uX14G z;~LLPgA-9bTn5|a+Tj4VCPakq-ggX5l|Ged6-JTFupT1l!wtI3ch(5cY!*MG-arR8 z3n%0@zE`7>#1+;k(Qx?jXo)3%&(HZJ6#lK5rp!iR^N-aZZ!*K+!qu>v6~|EvPHA9b zPG}N#&pLe&kB!y7B_=M05n(B^Y`>wM9NiB&w+9IOi?i()LMI=sUk{hFuMrblq@8Hs z4$x3Zxm2Z~B#j&+6CXTR8Y6x>t3%`_!n(vqq6XmMR#i2t1P}AS)EKf_e%M5H-{fi9I?nD254GWr&5scDz;P{n z^BBawgj&w2_75$Q_Ax{yia+Kw^>-53Elem8*u!BQuRhzq_ANJFB_}ok(yPHmE-Qb7 zQwt3CzPe%GkzMgC`={aK0^Twr$x@OLp9cni`K*7@Ku9F+wLSKXf^|e_WVT8wH!4fSKB?q(f5vfiX+e- z@tn@UyOZY2RHfO?B~KnbUt90#)>r|LZb7h*8uwIJ`m(=avZH|y0frMh^Tm%3~y@Su7vQ$K9koRp#F`=LE_tgG7WI)B8{S z9U>KxoPqY?;(Gl33nu?)DC76?jHcP%#a8;nB=T{4B7n)V82YM8`-H(m6Wr%k@DFik zpWj)`*Q!{D^<6!c5E}Z>#KsOr=``;GDAi=W+PiF0sU^>--0ys^3DyP#52(@dbFC?C z+I{U6Y!kc1b=FHY2=l$B@NBo$T#h^l`3Z8UNC1?TeZc24;iU&Y)l4`NLTM&EP_8#xvF}r?d_o!}=^6C~ zyR*3*2|a{{uG$Y)^)&?<4DdTqitd-}ct&%)tl|?r$$ja`Vnlu;7kr!Scnl_TQ9O}8 zD}Y@`e!~iGdqcWnc#@1@Dqp{ejPYIJ6VEgCLa>2+?FELM{?X zq9xbOmZYh}!f+GMqSaTY1uhW=N8HC63PfTf-(Mrl$yFCCb1?B<{we@_e1_eI4yL$O zk*`x!`^X>Fe8lVKi1c{F$%T#h4N^N+R*VvL4EYo5pSRJZ=vBl%ANozWMs3l1DHwKR zVWoZc}>-F*Jb$eNqaQwp+W zn+Dw4VY!UN0mXwX*w2Tkg`ay~J&@jTd~*{$XP6a{${mEHy_$zZK##`w#X|b4@_Ybh z1lba*;L9I7(!QT`(0rXLk+$C;dImap@s!!ZyysK*h8g{aSC`9i`*Q_dkQkQ6gjw3_ z^Rht7$FAq{R$P?TIwDaw#~vr;flYduV90rL;! 
zS=UWM`!2AHOET+~phU)3PlrzwGqfj$-P;~v{~S)xRxgtcT8+A=fT!YV%$LV{ddvag zQOULyO>GA4?gk^n_++2=FVT{Nv|lik*yq-g7it`4K0Wrfry97}o}iqYi3TuN$H#f7 z?WSZQe5pScC?frRqnI7nDfzjR;0Y@|1v60OVR7;Xwp*3a4qnBt%A7h48NHD^E0 zLGM+FMYg>%aRgc<+{}?3!L4?tbhm;C)x6tq5S~h9XY)A$#{_9r{!^A>n~}Y6 z-{GR_S&9Vr>z_v~Ok1U2k^8b`H%nbehN!i<1#@Mmq-sly9$?D+TD7Xq|n+HO%*?!}ZmY`cX7@b0-G}2%zH=XV(^DNz{A%L&-k_VR99PFD| z-b7WXX_4Wi<73l~3ZywcDye`VNvex`o(Q7-U9z(5}#41cOH97+O+;$uW)oWWi zBh(wMih<#kXefHMdJw_vxlu+rAEi|WE>qGA*67zII1S^6xa+oyKGFivCQnRuSmDbc zkt0_M`O@^R^vzAWpllgI-pU`68A2;M_!BRXmZCaHw;CUtS&Lgcz6&@YJppNXjQEvD zBCD_$IZ;5PjWLWd`8kk=J6jqZHv`n4<>koLw%kOMxcm0Aj_nZcO}v4JgQL86P9KCxKDg6_Ri5-%GcQ@vzoPS z+&|}p*~VnU8D$(2%ewM4*Ua3D)qT8&^au;q?`4ii$Q=aEsi1&yIXI)7|nnR)R_Pa<68?(yMQpT@US)0>71@XXdh z547c%#~IRNtrsl7q&z03@~)hZ_)F(E2kuL1O3$n_G{sM(DXWpEF0~|>wTm;_iyM!7 zjZpntlGqfj_cI0%5c?VGiLwz8p>72>aG)6Id|M6GC9CoBD#`0^FApZ5vpO^vzwMFi zT(bQ1PtD7x4mE6jes9*I zVJ|wiB+{q&>qrV>v38GT$Z@>{jkSgoyW&)8yi(%HcU_4?9SgapI_GN16&%CJM_GpF%K{fM3l~*+?c{7a;P{^OxQ8J?`z8V^}234y3Ws2{$3TrYwhq~Yr$-4` z!s*{Xpt127h@QFas-SP~Z&<~Xz zG&Yk@SC4MS4U9plm4;?ZlO8&jMVuYlftz&7C09;S(jRl3q$P*%G* zVv5W91ap_QNBn^0T@!ET635A!J}wN!d+viKC;Z8V!X57O4QG~}38wTN8-|U#Z>LD| z0b5C$bL%jK>y7M@;S)I%OexJViQHZ8ftaeF_VM!-%}y=$yqP);1`=rYZgW2z*(9TI z3>^k7cU<3cX584ObsbxHDQvT~`orJY$4o~T3}kC{5-DxLZzhcSr(ZkApwq>k%u7$Y z4l|@3pSe&~nN@W&qx!Sl541!?nBpyak!x5BG!&Mj70no9XBj)+5R7h8(95I4Io=kX zyfl?~wtUd$WSJ)($Eq=yQP^T=GN%V_Y#}U`?`|)k4g_UL+{?T1b3!}`KV1oo%@ib- zaK^1p&7hRsix%q~WZPbk>OAIa*o3Zb@t%JIbh$Ie91|unJ*eH z!!K=FvrS6nzxv==kL_~3Yw!KQ( zGX4l;;7NFOp97J|Dy((H@k<#H=@olc#nezBa9O*4-F8+1{ABA->PpS)FLDo4+713Q~(UQqUlO1J9lM12vRM0{i@1;4Bw7 zqpBd@>^Kjp!E+qow)yP5!5+=&S~t4qi5RrrnX+Ao2mDQShYbv)Z@?_HLZMfFlEjOl zzA>{7A5_Z5dud0wi5&4F?VtO%uIS9!FFJO&)Q}}1vgy&N$*eJzhfM;>Xc(T79pbMx z2h?Rl&keN`!x8@mJ6OE>r{DLpWa_iFfc65T@f_CKK|s{*2>mp|;R>vDW}ue-g+}E^ z4SQ3Mctd-i_ER#SPwne_0Mpz&rkI9+p~0*193xpe`MXW=c48FA@i#)KusBkN-589~ zWF8=GEoe(&%w2uZo`P@PxFA94)A#t zqA16=)?zd@4CmgrJyn`PZ~?O>7k?3E8i%zDVCbQ?T~<>5n6fqp#yP3Q;~)y@K7MR^4sWZA`-Rc-Q~$`o+5KjewdJfkYW1z)Xd zMS$kdGUq*y(O({T(cjK|)_?QBL%*6f@opznPpckZ0_fS=;3f~$(|e1~2UYLAV3qr# zsZ9$5hsZjHbEPw`{lf~mz+9?)yyYSDYb$nKyQhD}aq(Z<2x2DPn9uL1k!=|v{kK3fCXA*WlqfZSeswlh`oc9#GG zqJ$U}`LU@AJAHk(or@KtYP8*{xOUW=-yi-iy`9&cQLInN#zRB{C0kDOLg z1gaSTfeo$ieuJ70c&qxwnUOi!+w1M129d$MMC_=CWB_8 z-`^{f7$ke@_s4&heSk9+iq-fm_xo#p`j4;4ts%Rrp9O8m`lZ{qgGikN0aV}KHI&3Q z^=nxzr`jJZR(Ii;PJ=;Yz@usmMYsbGZgiE!!mKY$fT0^W=a^p}uPyF5y5NE{|6flS z{MaiDKApu_2J5dU!wm&GzO(sF{e$~A>R+J?3=!)kSpI+?(aUPLYqU^rd&2hC{no`9 z)S}EmU!=D|7DQmaaQ`Uqw3yX{sk{rP#VlY-%kDe3f0gb3c*D7r;FGoksMh{q^F!Z! 
z2Q<*13jYrDXa5ZJQY$E=8wk!uoh6k4Qr{emzM(Mj#kdk|$99xcww;h~>W(e>E zAoCLM|M|E6@fH8}w143;qDUlM|9l0hTY%*r&T{@+Mg{ev8cOuO}W9?<_b?Xv$jasR&(qyOE%!WrF}|GSNU`0r!*|2ouhuMoX4 z87F%IY7s#%m<9o)Ko?va%Fg%Z=AbiyzPEY=S?A19{&o2TxmgQ}k&pHWaX2pDGgf zw}1cVzVRQwx47lj=>?meSzFTfXjUN7^LLqAtZ4ss2+u`?-dH>-^#68a`9IO}0mw_K z_Cp#HE^Fl`E-R=M7KtY?^5>Ag{_`pRdLABDKyM;L6CMf5LjU?iDL?2%H0}4(_4v1o zXpnEKHv+oZbg8yK8(^r;O;;L^nVoKq%$#+%ZKJs4&UEuffhjb_izCppFTQ2%9 z@>{Oq1E{HGv2YTYhH0M_?hMR0ZY1sbU~7~DD9@}0pq|wi2OonP8U{?ZWQs)~Hf=Q+ zBP+jp`vI^D=7!UFZ?Fm$LCJ)pb20)GxX8f^TKJG9_t%^B!2DGR`s03#1HE&O)f1(5 zJqy~8RtJy*M!j#L^@YW4k?K=HQhhPB7fC@t)?ki?BmrPb>3fx961P7&k2=t+0d#>r zR(IetG~A4gU`QmB%|HRF)Z-!CESPPq$KiZ_nvnZXo>|oYjAv$3lniJ!=YR(?o5*5) z-K$WfG6SeMybXu#h!qFr4AxPG0EO=wD2hS3y|4zF9OmH=lEApN;3zcUInJJO@pb^TuZ}MvnBot2xm0 z-2;#f>d2@wP^U(Jl^Q$>ragB_F#mwr0Yc!%&g_2x%FIAN+%2j7J~`FX$wPQ=iA^NN z>DI1b+_^)>P-^f**gc@@^Y&`0C4O;4?^rAKz|N5l-VVyM0>xH5gFaTIi^uGRk(iV| z?^CH@_8U(m*w$lp%P+jv2LNDo4x36Tlv?w^tloFr2T*bl^COilz>@71k39a@!v;wT znxVp9NBriQgW;qP5Gov(sLBny+;zsV2ktDi5n5QQ&2$@zr@Dc_tlV4eG9v%0+W(eE z_KA-t98{Bk^T^Cb()smgAeSc_b3kOU!4LZw=$I0TpO6tLK`J?>=% zu^J@S*mh~YJuv?5P)Bf~glF0h|N1-JIS5Yn_%BOH-T<5f<|}a{!zbMLfO+6_A)CuE z0~YxjKRMub9i7w<-2#$O8GsXX42ce_VLxd5P{Lih{PqOp0n+DI@TWlOcUi`z!hp)r zNX9>v2b*jZ`o-G}^pQU;X}Egbu99+*L#hdt$k}8b7e0h;J@+X1|4gIz;fbL2qqaW= zwXptrDR2TU>ZWBPWc=w-51lRW3}hvN?ZI^in00A4KqDDilh@AcD-3zi`&uc?kNY=y zOqMJ$H;T8E;HZx{Ms2whZ$w%mEtv;6yy$tLvp3IoH1B3D!v`PLVJiB;ImSHabg2bu zE=RW{oDR$St?bgRAOn>c&nU-}^%*GB?+)~R%>dhk@(raxY5K?qR^)n%6$L<}131X- zgGX(Yo{jmryeh?+HPbqXb05AUtxVAU?n>4x3X#>^-9i%P2TU_tl2g#QL|W4~G{APV z2kQ2xHbkG;fi%~3Lm0hdBGsP7hKEqec-YrJ$g~~ibT1XMJzr;S-mq6q=fAkJv;zgg zf-#}v_DTTUv40NWTG2^LwE&Dw&ZhF%o15YR)V6PFXavgpuAsC(2}Ec(Ydv&1&1g6NmcX2^fxAw2{(1L zRTCkS#U`}5kpPkSR4ETZs5ZBhi!{k5yk}12X=^YZ$ErQ?aiUo$GZnC{! z)}F}-;WPp6Cwraf!=_T8+nFBElSLnkESH)JW2c8fya)hGwYTKJ|0feaSqV+{u>l!M z^vE8KB8Tw>>0rh9-l$vN7oLj+nLgs1UP*uo>#h%VGx;i2`-2EH700gcB>}2nAI9X{ zqRy0qD%8SpKnuf@IOEX3jy3?WG1CfH+d9jtndHy00EEfQPGN@r=PqKD2t8{CoBrNL zPN6R_{pF(kgE1>7Wnn@kVaVIUF1{o~ADlN6A8Gyx3F|ZU!?`1R0H5Ya^d@XL^Q9-I zV^=6Pb`&ksV!r5u=Lq_GJ}{KF3-D(^4>C0VlS+0oEwdB}z_`j6o17u^QyR`iNlx2& zxAyVwwlleEMh(kBJm8PVm?I-@>jf>PjgY%EUpx~aJgUY{5%MO0E~OR1qi?isM}cNP zX@VkUXV%aPH)D}sF!oQXS&M7#1BII+JmC&FU62StIbgRMX@>1kIPT-OxrW~R#_IJp zN9|po)_}05N|Nhld}w=y^eAKBL+R^l@J<0X{V_+*1=ftbqPjk`7X+0kW?#VKD%yVO zTlor%u6rr`BBUdufF@`OHGPv)MS2$}n3emZxY~dBfTIwW_(kdb7Y8h6DPNwFHj^Fn z4)=gQ;Txq^m~RSYacw&W@)|e{9UrIwu3;S#4*fSSuc{z?mqQQt?&nYz7$abTecT5N z#^^D6-pDqZu5kk<7K#_Q>r*M_)DO4hyF-9_mQwrt2ZzKG&0{2)O$ce;Qy;_;`@rHO zu)oeS&Y96c8$x19BlXyg;LvfK@Y>rsFeVM+2{Tmxnb>=(p}ARi^^E@_7{CrEzolC~ zc6I&+c)G}AtsX?OZXBjWrTA}k_=1S)C85Z{#s4IOoe}KfJQs;2w{>6U-_c0llZWRx z8gGEmH0&-rZeQl=7Pw9}lg2gDi@eJJP=%Hkm;;jrB`4mT$aTYcwL-y)L(WN}tX$z@ zO9_DO8?8wUK!223@r}*DD5`rr8cB=*D!XJyAX#2BZxsmc598Q*jEO|je#v93DJhQY z&)q|Y=@c(ue3&3WC3660!&f7`nCS)#6_@-@7Nk|uMp-;NCapny#ZZe!&^e)N$&$lL z(5=c?Lq5BL?cVGJ*WB8rS&m@2TRJO z+f(yzH!oYUR@Xzhy0E+3A%<*4`8#o&F@Jf0TgZ>6H!iU%O@7nzXqj04ZfMc5wMRPI{W>a|`&A_UtEW8T zLNFK@SwbjHLa0)#LNO#lQHz16>G$8qk78tEKl*_pgiXvRME3eAnhN;OWaW1oO!h;^ zi}U@9`vdyc_7GRjl!lok$G(1h14H}diwjQ4uoj@QPaj8hxBTMXkl7+o(3=_(-&hjM?r2q~Zv82`|5Ev#q2cIfB(3QU{QzugTIv=n;|1F(YGs)4Onm1}nj(+} z>u!wPZAfqxy}dRREh)biI;H8OJkx&4woKYw5fo1FhMl!)P&!YXP~2_4B|* znMF%$E4jiFS8Qx+z13(sLFqaU>*&?!zF95gY%6u^D+5g7N#)a^%z3blCa7i?Au-;I zOvG2WCLy7ixo6Px(v>!L_#l=qf#8`@<%6)wka}A>G#kO$Ls}{7bKDIo`^JLafg|DUJuRD-oUxZ zD6E|sK1?&8OyhP%*kp>{zaddFo+#sAq;OZ<1HEkG7h>cs7YA>s&z$>UQ$;`~wCUla ziqn@Hgso6HOT#1~_{9$Kkyy5rO-H)>Pr%_7vNadLoCXzEdN;r( zQp61$3-=liXLDqG&|*93B@rg{BtCJ-nmgC6crLmnzIMULoA0uxhoE34Am3x)5YyLj 
z(?_0Z+Bc<-pfJ0lM|q&u6*x$uRQBAy=g72_hLPp|Tj;%9f_qsOJ0-K_Ym->xIL9g= zKa)>Mt0%2grp3ym zwk9L+A~^kJMs{oFv%XYq0a^2uC|#yf8V{{ia&D%YmHApTk!}V%UJXA-*mRmWP*z_7 zn~lOi1K0?dqPa+&6e|7jvtl&e>P`Vsj!5U7fB!`FEm6AA9H?%{E*AU4VKES2J`s+z zL`CbpKv>`%9kRmmAkHf=S2jR}58p$p?DQ%D!^5e!X1F0lZk`VEECBe!tnMzI`-+j zsdOgVQBzGx^;>v3f>dGN`e2)Zm%x7=6<}a7wjts6U?N7Ummn+(JeNAY%zQ`?iO0{g z+%*8dpUlTaSf~_F5#00YFmS?~?060$7E8Z~U;Q$)QX}QLsO_Xq!lDpW!QR!FD`ho4 zxS{koEwG)cj;Dta-RcCA%`VerLNuwfAw_rRWkz1aRTtop5kurMaX@fkzsVFfnI#^f zrPJc=_Lco~n74;RCljH8!M{I5xDPyceqV7nS>uNNo~s3-WUv9a5JaCWQD2GcN*jze z43`!ZhcW((1#qw+s2-@Wuo$5$he|j8*_)JeM_}&sm1dAF zahr%VX{66{$)CMxDm~THI)Ryq29G99tSD5CZZn(HH`IW?Ff)OL10OraoxdLg*wz|0fVmBw}#%6adxN{yj{ zbLwfQBOQE?n2^Xk(D?)iesE$w?Z?q5jI?T18IT`_JD~~YH^Y1e=IH!t)-{7S9wv_{ zaYB317*fEMle(F=0@jVHQ>X4r)YQo_E(ulRnr{)qLmu%SINQ&v*V2Xh<*d3MbSIZ z6XH?!YE^PbzSp%bpQ3te6SV03a*O5C0{1{d#K`=yDYPIA%VTgMen9SzKbeEPPDx@nykp(MZ zegde?V%F z>#$hdw9Z?9yiJAzsy=$J3B$c>00$eQkQsgT`CRd)BY3(Ng@w5MD)d=bh30p>I{hs@H-C^KJNJ0a$3VzV9-T`(>}Vq2t=9;mkz_=z%Qu1Jz{PPJsD+X7(T zy4^DHzBLySCq@m95k^!9V4fLuy@Y`59LCJwO`8+L)|%)`d-go~6B0DOLO#8`6PZel zL`6%gjOwK!$viNgkgk7=-IzJX+Og)SuK9pwa+uq8b~%#U@nJKL(eVMo0R5yEsK3=N z>(e`6CBVkgKg(!gfBQ|CG)bDLIg$Q?SEx89Mg&UR~Qcm+&YkKeNYP&s=ktDcLj*OHnfB_^2kpgCiy ze}8Hk94%BB_RzJ7-t?~=h4QN%-6(37bxQl6xLxEnrmN>6am*PaAgnOHSgS%5HbG6P z%qhGWBU5AfJ!kZHk7M#z3>mx1+juhkN^pX>GGO!pATSf5)b-jH!W~eJX|)g-Y-%$n zF)?FByabgh<%tMePEr2B(88Gp3JerGGK_N7d0(p=17#{?Df&2*!ivWVtg{HjSjotO z55=PjfkM{(?zV>^-H66`36Vv`NIlaTn3DSMS>!0r9>JlN(Cu1pKn28)6T`MDo4|^O(nYy@ZBO;HQpc8lY+SSNR$!`mEq8 zp!i366jbb3_R4NFqe!7NJ@i7EK-&xr$z%?=i114+&R8YRNVH!mP}GHHau7tS zY&&qh`CFsUNDY)UpBCEm7PMq|Q!Gjd;_9h-Vh)g;&E&%2IcRlS0OjJ+YMQdPyl!A8 zQ4~*@Q*R2xq!MBA;&zomq4CojfH0%9ET~4o&&{Qem9ZGtuu*6e5oAlV>E{RFm0Kf5 zy~z}|$mN--C|ELE4O_PZfbjh(N93hy)k|hBl}z60&1MTo?mJkWk&i{@DE##!_q%p> zSPLd7SZ2y+x5Yd$k?#F0GmXRiD2#jo$~L5Ho*n_PZshwkjAhOXedy;dm>SyjHQCx>wH}$9e-jG5lbRXUO~Q>(&L(04RWkJ zS@N=)ZtQNFB8Uo+``wP&gd&|A0G6%~)Wc%4xfn+R=EA9~u27Zr#}g#TtPfF@pX99U zB*a*0_$Bk7mAp^f2vGa9#O2!$X~KfkY}>C^;O+|Vo3LzN->;EzFxm_0Ak$FL)?LWQ zjGzKF6=7E-nQ@be^B-Qgy$BjXwjgjC4{(2Y%Q(|0?|r9I{}QPhiv*fF zHLp|SlF(v}42q&fevIOcz4F?oG>M450{;K**)a@T*3eHVJVRn_y{kPD5%h$&xEUUX zNP5P0m2QM4*|nBq_6la4A7q8$GmafWBzil~O;F9nkGwp10iqxSUM%)aTINt;u`*IS z#jKF?K*~d+H+gz)CEw5rg=(w8WSi(t!dY5`F3d}z6zve3)UyEYaK@YB)1h3HWf7@8 zuW1qINJAVy`Fer{rsbj1V8J|YN`HalSt=`b&|q;b$B*(p-(~_>oN;Hjmr?1Qc}Up( zlsXS#J=h=ACllu%4rTG7g9|8$YCsc7jjm_!xBJ{o_{+x%SI4u{jl!7Ai~w9#T! z;ep0PU}s*lgtW?u*aPN;m1U!J6TzFU6R1UQUB0OH79h4f_cCRU*r)D?oznQ(x`G8E$^?Lt2s~===+qV#m7IlpHrI4Io%__P1@2? 
zT2c%tO@BTY)eq+|imDJP!s{yfGkvMS{AR1+saI=1QEHn3lHZjr1oTQUW^wsnN~wx0 zDfe7!gb^$G^A7;CfFvj4Nnu6u1T5-)-H~|lv0JM*1e(y`-cTDcgrIpn!2CMff33_P=@ZDtJzhvf-2fe zKrHDChVNV4PXfgEdfuh`bYOV_OjcBj7e*L$9|eK}2AU?MWjTP-K%Tm!553SKli8e< zU;a5ZEi3u`tfmhhF{4_QQSY+jzaDyMxduaVHuDr9hw#7o=;1edUtbwy?VIY;rHreA zEuAms)iWdoJpnsJQZt~4+q)<41S`ZOcj|xDeACs?G#;ghLp&FnQudd* zSj_Hug~D+vb*xDKU>&qam70qw#FvG}0$MUCv%Cw5l#ZOgxF9((4wY8}pz>=63}S?_ z`8%T5U(yFBb*{d2=haj%*v+ZrlM*ny3Peqkl^lTJVjXQ#J#)Y8(_BAS6)t_J)8Uud zOdu*>K4;NT&x;Ei{nJXxQORONRimPvq zpm6$X3H0VyUtfUgbkTIt5Rq*0BD1?wl@6MZ-2^XmdWvRQN@38%`H1`@N;+_k%#LR` z2$uF)8l>}isaOY!g={54>sLU?lsSJ3P-)4^*`p%Zg->2Gzp0r99TD|?LhAaXr>)>W z=vmx|ct!<(D9>Ws2j_;pg!G-w`E?}vT(Ot7{_vCpp6^p3=)+oZe8gADp+@@iwI6+u zRUQZ+sOq+fEPm%t3(i$zxR39h3=&l!O?~e5P0{;3Dqtkp>D32I9hq&nR4>78LmsX3 z%52ld8#62Sc!~STrrwC5y5Z|oL%I;V++4yd^+{B_ouZm(u7^%ZhHfQ7f=d2`XToaH z`d4OG8&tF^5d^Qudc%96qx;ruUO4hF@_!irRI-nvoGkBl~OuPf6EDM>5uuZzMtg`sP279~$R0T9KW7bb6V`1{e-6_lb-C#{Pf24x8~38g zD=@Tow}6k)G26z76)%fWWbUH*qDgPARQszHDPrf)g2=%0gh8*_{gDhoU?LT%DY81= zl1CfrG`GK)q9MxTcNtACUzD`p=sEh+opB$26*O!;wTSPwWBpZcuX`(SNdRl$S6GY> zJ*F%SWnZPBw^o7q4oQm}@q$B<=q;)TIgeFY0g!y|@t}-3{fw!2*VhczXLx(HDXiFT zYEgAaF@;jxhy+>Lr{n$_&y~LTq@JnW23swLu#ZN9I=Rn&VkrKGS;nd2+D`RACt)0b&B zxD*?%P>~tab|2In9jOj*EY=x>V5kv0dB)Q5gv!Xu$PqFQJ=vChR+NHC)D<7(2>63z zLf-CNcEBAl3@655eS^!eAZKYkCLWDN2B{6MJ_lCR%#(#}trK)Ilw z*myIDd0g2V;1MBYaiMwxk(>Mf|(3Wh7frflp4AQ z5ll7`XbG|@o){Ve&G=d;-WIz*t06czjBtNR*>Mu@d1As)CLI$VUOTU_iIt{Tv-U2Z z;(_sO$_5ojUnU!$TZ6i)38QXJFtNK@qN4J<;2dmQ6!P&TwbO>A6t3ccN{)Gk z3yDB|kIAS&69D9gBcn?Hku zOnF(DzW^8>4PZB$in%vYa&{UlqP5+Q6$LOnLqJET7Oxx35^jgt!`Do-K&O9QsOs1u5l%&NMLbIwm(!6P39<1v27C~&8anuvD z3VUe@40lR7F3Uq&K!lv_#8j4DANP@7ET3}HZ#!d$j}vejODZ4aNV8ba|E<}Yh4w2p z&#i}1r4?A;)dz6*w0%cuw-W31APFQ?+N^5lEJ%dGw_^O+FUvwek3MNw^X+4FH8}a? z*8tZo>5UOLH;aE0{+<_XKRc7tY?_ogKAlWydK|}NY3K+XgFz87>k$Z4y3~K?Pn!IV z0(9}W-A;N)Cq9(1`I~AQV%o- zK1XHT0r`)_{!|rS9bW)iYu5$x0NKPRjVzBo6-b&31Rex@z&R_P*qD9LoyOq-?vh|r z3n;{?;%@#tS`oz)ZrdU2Xla3}nftuVl|3;Ta)r!qRRzy5&W0X+T|`DzP)Rf7$A#}0 zVs)dgu97oJQW0ypr~xOJV?%*G%A{Zv8A@?qSl22}KzqMG`tu~i>UK`Ogh6O;j-fTd z@CxZb2XzAaaE}8>c8Ey=AtN81JBTKxmVnt)C7ihli_(c5Cv=jo;J_>Tt>j)#&I%}P z8)@3^a&8u0p_NkAt%g(2vTtQZn6Oaw+UY*P`zSOOz*!S;oB;|R5V_lcLU@r0tuXg4 zuCLp?L*sY|!YOU?3d&h7<{+}Kt~-}j^6?L1MlYaXjkT*|q}3T9bJV2C8P|>yLx0Ea z8gh|06uLB3DX|3bI} z^fkwdRlOM-`T5&eejbwlC(105{@IfK=0P8#J=C;unQq9|K06EN-^8shzv@A&nz(e1 zeUU>sQ2~Ama*F0QmFATr0&8&E2(0uC#JJh=6|%%{hIvbF>1Fv*hH37Nunr{lS-Eu} zE3jZt#=U%GW#GhR;tqxpn;*R@3C-Wyd01|!?ANLGJBP)pqh~1p(ZcaIEG9v~;auz+K zOhCUuR>v@{*fr&bxa(@njOGc@-adJN#%$7+ldvBo@+0GUX4N!MOU3sa+yKp;icBQc zz~s~uO;@HZ$%kIk{e<7Wgy<-VLxa_p^|D|`4<6h#kI1y}5{X4|Pr$|60gktrQowrZ z{R@ZeYS^|@i11X|-DK`5+ac)kio&kADpa4p_u|4?jc`B&*Tq*)b0)uV;YVrg)E7;h zvPAm5xqZ2PX$&Sa(==J2%Cfr(fF7MI`9M$y_K}xqRgio6>=J&xAxQK>?favX;oErt@sMWq%oj~f|2X9LX zmf!-mw^3)xw-r=!wReSYu2fvR zXfED-XR$OFFqNg5v(#-{DITp6IPkXZ6rBBgBqp~!Yd^!l}>3bhYvs!@T zc6o5I^rS?>kyb4E*7#=%|D$}pirK10?ja5B9!H1|<`s=fo>4>1jF(+$;$j|$joeqPzfF}_A^0}FTxJ(ybvlo9jd^8)&^zKyZMVb7ZFA^sV`UK32TzIJRZ=P@6 z6mbKG%EhMiJw!_)rvF{@Xu7nw7cwiEQhmZ97})p7n*L|@eAWOB3gZdQj0LKKCyTpv z-xchOO5tZpKj*bBI5ctl5it8-;{Z9!|6=dGqoT^TchMG6K?Ed=B*`F2iIN2^$vNko zqvVWA5)hCaiX>752~}hyN|spU97J*s1qD3BtY`X$i*;LSdvg7`3qujWI}# z=1N!Ahd#Fjpm)eJfj&8IEdwd02QmYV1MW#1v)7$-Ex=Z^^R3XW!f;nV-~r8^IIGf1 z<6yOhG4~u?(iO(ZZyGRIV+-XE86DT2Fkm1qcZzntU&(Z~J* z9`H5Ll`{nBd*l0R=&rx2tQY}~=hy~98$&){6+5HSlNMq9;_N7`Gl@jJT#NP+~E0%1?SE z!1GkI?6=Y|+2b6asHh*-%fX%bS#w|Dzs$m}g6?(c&W!UkhD*Q0)+Wq~52XGwC*dD9 z)1I?jk-LY|mCvH_kG7Y5b!a6yuKUF*DKX|-Fm1OJ-i4!$5?ILSaBn@c&_8|-=v8RR z>1Kc$0VzhX`yta>mcpiN06OJ9a|+9o6fdTHb2BLQ)oMHDw|6(w&Y$o^+-}+@2PJv@ 
z5?**d0Z^A9XZj@WhQuX%$@j#EGhVJ4+WwGMyooJ-awLh&i9DPOMsrHzmNh{s51#xLE zOMsGNu|znkFMx^8iRX?gf^x#d>;|d^Cs1KeU_P91mO=C%vjnbR)r}K~O48s`DqV_0 zr~sWu+OG-O|FW(%X5}bj(1u1rgu+b+S*vkP%)IMyp#5&95$TX8JBf5 zE5iVOhb8xVQGuN-^i#n-S$M${RGp@1#~jd0PtFfa(@iY{A7KN)#eFYe=0H%6?={C^ z5umBa$r=F6q3{CweA#@X67@1d&mL&bx}sfi*{e81wZga8Q5!=@G2TzDW|YP=K(Sfq z;s6$Oo!t{3Xb@;n1HEc5XFm~(0Cfkc+H2HsoRef%tE&WMIL+ zm^m3)&-sRw-H2m=(WDOm&R^Vc_OJ?TV;eAuz72p~lnb@XcYqFLg><9?Y3^^Bs+b94 z=0M=0kVBr!&*(&^ZT!|5$a=zn%Acqm&vl_KRoeLbM_L(7hJ_G-yg;iD&^a!vdpY^m ziz}&R(ec1+a4?7Ts0+aU2&ESF$rV6Q15p(d4CtH>i=~%J@@CN}2?$z+%seLBUgk=` za|J|0WZSEC>c1>C0O=GFt=zcBcvTGWmdMwi|1N{F17OCyJZz7jvFa`;Q2`?3Er31! zUdje+0_a70seSOd79b79e_Ztrau~@|z!8{pQ=0?;k=AtAfsHZlWG6uO=mEeeGNL`PN%t{(}&UZjdEQ^Fb$t34j*k0h~H~fU%&rJ^?i%;ZVi|lA4y1gJ^`tF1CtixeTRWj#~#i6ByZMk^gWO(oA7cz&4;v?R>|qC2X=JHQZQiNY9I^!4dA^B#Wqw|GVj zfKvPIBkn3jjGw3l|Hr4TBG7bWAWUUpK))xTCp39;`@Y!ue%n5fE!@E!z9z+)00FW@ zP>eSckUUJb>Ums(bFG5=(jWmk76Bp!PioG)_?+VlOxk}xVa`Au z_)foZ$+)7(LHjT zTg3e|PPItFM)dJ0UGs%nf#i)s3=oid0?aoaEYgB9Fe=v_BH-dQ*Y#zFZEqinyT{yoU^aq2E zCj!`$e|XB@Ft0Ukc!^m5F!_7`X7Uag5e0|96m7;KHJ_7RDosQ0l`nYGg@FuyDlZeS z4`KKuA&&MJ+hEB*PszVN`}DQOkZ_C2Usvk$Y|HCAHI{@V9q`Ej9WWx^7`XWUZYF=T znBOpt0jVCyeESB#R{7Tf{2lV|Uu*h(r39C||NdbBPQdSNDx_cj+Uw^_0C#HG zZ#VLHOP9O<-zAKn{C5|wZ@vHS!hd(+|D<~XR*$7AhFT{QXZS0!{*N2(VdyuYiE{)J zL((FfEaty0^2?>Kt6pE~S*3qC*Dn&VpA6c6qmcgFoxAP>{MYK919g#H4RxO(#J?W% z<~ixL!&LvnVaB(`xcxmI`nRV5m-ME>8!WbCCdI$hwX&Fxb04x6Q68^gj|FwmG z9ghE=i2vRTe@4Ln*V@AKQsa(W52&h$f4b=%)E#Xd`9h119XH%JPP|wA{2Tg;D1au0 zI%|I?ivLN>{}zv+xZM;tFvIZfYl!=Shd7m?n4R8uU2Q{6tOn=&DipDaLRKe%wtUh5 zBmXq203*Z^`r=)p7N<3!q}3p5FyE`LYh;&N>xXjpF&dsJ@{ciw0F|M?Ddqmn=l%ZA zx6#-6(&zAD>c6>f>TCCH{;z!Z-xB-(Z_f95sPT#(HZGa5y6&oj+RZWh$lmK3hp(^h z-gS*5OKV`FtGA)&~wvT;WTxwqV6v8A3q=Y`P4{LF|?%~9PEy1}JM z%y4xqs#o2|pt7kZ>i?hc!A-YCR`<8K0I1saEAQp=IW;=LZQ15E`?qXS1$k7MY-7qf z_x{@UU$EGpTik~QWTKXs8qyuSeI_%@%8hHYr7%OMjT|i6z@58`1&l`bY ztx(`K>Mt{kx|X{|eC7Hb9pI9{WK(>{_`iA4Q7{m!`4Huc|4Qot7LiaCAb)&XaNMY9 zQWu>y-f-O1s5)0G51AM{QYmb5dhtJ5Rkj?KW_kOgvC@&k{>Nakm9YewiSPxol0Aoi z0!jbzDV}G~1HtYn_}Rs;uCceBj8D{nUXSne0VN{@%`ryk!Y>{MbsW$vL(8gUPPGmH zA=EtoKY@{75$i7iWAj?Mx-op@Z#uo#OoueM{~w|0 zELLFt@Kn5s+|k>`WnCC>!>=FH z?YpxEhsQT#AqM#DQhWz=w#GOv%ISDD4 z?ff(hCs^Xb=UHf|Hl4Q3`;k70zh@%Za)hVOn_7!CCrw1JqOwRk<QEDFmVZN_!+fA&NL=Iu|RK0ICgYy->+PEKV`uSR6!Z4tJ}hJ7lo`$YFH zu`1gCDJGGh!|B{vYpfoAah=j#O)29FHN(4E@>fQxUO{_(WYN=k{GC#UqGBRXk@2~g zk_b4#r5g(!Z1cO?7Xmh(~HoYCwj)yaTk76WTO!U@@BoEwa$5Y&sdEy2Yk-aWUt7X%OG4btT8iCq7Er zd*<8!kP~(GbFm(GdH*GB_$T2H+<%;@K^ZI&J`dYxhtmg(8LQp5z3ZzsA`l@Xo)3vF zQ)NE!7-jK_p07pObniM?qle9qI)@%rrCV1|H^E~rjQZTp#*j$`}cLjg@WuAi; z*E=Yt@kh!>GBStB7KlDcaNgHglI+gBN-nG9M>{@yL;zto9;iF%J3E6;7gLO(}>rLHELD4v@`-PB`y512 z>~`tePW_iX?Mh}p22R*hTZ6l*b&7ghl*JQ~z36J)@s-4=nJ?Ayb&BU=^9x@7CD3}X zvBlhyRG19E9SHwY;KDX_aK8bItW}Sn{!>gya=?#Z3*r90+M_0qW#9k6 zmw|Fh4rW4+B|^;zZ)yHYP4_2oe%)ze4krX#m5p6GF0p#bLT5zT`9xKVA_ z5B@gmp4i8$J5eGR5J7yan!!Gv)-MheEj=HymB9b7nCHIO=Vu4_t*-@|64k|zts+Bd zu8iN#=FL}yWBbfkNL9c_9{RTd*F_%(h12QJBGnYL>S%9LJE$?plo0aNb`N3hC$cdy zA#!B+3{hRv*1)f`u| zamt_O2f~OlBv*YrSHlHy!C&C|L-2DDU~o^S>xrk`=BmffHuim9doO7A)-{yB)*DN@ z_MKg53Mlm0-4hji`7Y%B{N*gu+AIMnhWFZC+K)yRvrBgjo^rNSp?xmelEm$I9%LZT zCKgw_#IORrO8Y*@3-c`^NKKpJ>kn~ChdpQTbBohr|5#1wwFXRN4!IQgKu3!N$j(Ip}4orU)@xw zQRzo>CT*YWytdTDwgZU%eF-JsvWPMcs-8^J{ajd z9>^EJcg6j-!*0)?8pR)&gg0-$$l#!bIzRQr?sa?`*h550vJrBiu(Rb_8(dYtC3MlI zyRf?h;z<^Qlg;t5ikwWV*EfmdU$FNhYB*4F1HGOc7L$r=jd7#sM93XIo5F2tIj$7= z%dV57BDJ=#tV^zYmC_9_q7ZQTedLS!&S^7FZNYDTT!B8@7t5FBQP?|*ye$Td!r$5- zKl4e#SJU{z)UYecs%Xo0#N-zc!F)$NX4o_nYseYu(3 zcT`f_%Q`u8h-Zny)e^Jfyl3=w^fEA+E;JM!^&x 
z5cj7;xZ18Tx7C?_)gFSx6t|o!s}rriu^%GcemtEmP$H=k2uRy&$ZY-E_TPex|1egm zn;7PSWVN@ab40+I-0yM$tl&|~NjtuJHC_3PU4m^ZLD!%MR3T(& pHrHe$6Z$hCT0@-%e(#a#w^v>iecn^;pZW*zIp&Q&%0Ci*<-dU7g<}8! literal 0 HcmV?d00001 diff --git a/docs/source/docker/index.rst b/docs/source/docker/index.rst new file mode 100644 index 000000000..2c92a4cbc --- /dev/null +++ b/docs/source/docker/index.rst @@ -0,0 +1,17 @@ +.. _icefall_docker: + +Docker +====== + +This section describes how to use pre-built docker images to run `icefall`_. + +.. hint:: + + If you only have CPUs available, you can still use the pre-built docker + images. + +.. toctree:: + :maxdepth: 2 + + ./intro.rst + diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst new file mode 100644 index 000000000..b09247d85 --- /dev/null +++ b/docs/source/docker/intro.rst @@ -0,0 +1,171 @@ +Introduction +============= + +We have pre-built docker images hosted at the following address: + + ``_ + +.. figure:: img/docker-hub.png + :width: 600 + :align: center + +You can find the ``Dockerfile`` at ``_. + +We describe the following items in this section: + + - How to view available tags + - How to download pre-built docker images + - How to run the `yesno`_ recipe within a docker container on ``CPU`` + +View available tags +=================== + +You can use the following command to view available tags: + +.. code-block:: bash + + curl -s 'https://registry.hub.docker.com/v2/repositories/k2fsa/icefall/tags/'|jq '."results"[]["name"]' + +which will give you something like below: + +.. code-block:: bash + + "torch2.0.0-cuda11.7" + "torch1.12.1-cuda11.3" + "torch1.9.0-cuda10.2" + "torch1.13.0-cuda11.6" + +.. hint:: + + Available tags will be updated when there are new releases of `torch`_. + +Please select an appropriate combination of `torch`_ and CUDA. + +Download a docker image +======================= + +Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use +the following command to download it: + +.. code-block:: bash + + sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6 + +Run a docker image with GPU +=========================== + +.. code-block:: bash + + sudo docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash + +Run a docker image with CPU +=========================== + +.. code-block:: bash + + sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash + +Run yesno within a docker container +=================================== + +After starting the container, the following interface is presented: + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall# + +It shows the current user is ``root`` and the current working directory +is ``/workspace/icefall``. + +Update the code +--------------- + +Please first run: + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall# git pull + +so that your local copy contains the latest code. + +Data preparation +---------------- + +Now we can use + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall# cd egs/yesno/ASR/ + +to switch to the ``yesno`` recipe and run + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./prepare.sh + +.. hint:: + + If you are running without GPU, it may report the following error: + + .. code-block:: bash + + File "/opt/conda/lib/python3.9/site-packages/k2/__init__.py", line 23, in + from _k2 import DeterminizeWeightPushingType + ImportError: libcuda.so.1: cannot open shared object file: No such file or directory + + We can use the following command to fix it: + + .. 
code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ln -s /opt/conda/lib/stubs/libcuda.so /opt/conda/lib/stubs/libcuda.so.1 + +The logs of running ``./prepare.sh`` are listed below: + +.. literalinclude:: ./log/log-preparation.txt + +Training +-------- + +After preparing the data, we can start training with the following command + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/train.py + +All of the training logs are given below: + +.. hint:: + + It is running on CPU and it takes only 16 seconds for this run. + +.. literalinclude:: ./log/log-train-2023-08-01-01-55-27 + + +Decoding +-------- + +After training, we can decode the trained model with + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/decode.py + +The decoding logs are given below: + +.. code-block:: bash + + 2023-08-01 02:06:22,400 INFO [decode.py:263] Decoding started + 2023-08-01 02:06:22,400 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d663.clean', 'torch-version': '1.13.0', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.9', 'icefall-git-branch': 'master', 'icefall-git-sha1': '375520d-clean', 'icefall-git-date': 'Fri Jul 28 07:43:08 2023', 'icefall-path': '/workspace/icefall', 'k2-path': '/opt/conda/lib/python3.9/site-packages/k2/__init__.py', 'lhotse-path': '/opt/conda/lib/python3.9/site-packages/lhotse/__init__.py', 'hostname': '60c947eac59c', 'IP address': '172.17.0.2'}} + 2023-08-01 02:06:22,401 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-01 02:06:22,403 INFO [decode.py:273] device: cpu + 2023-08-01 02:06:22,406 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-08-01 02:06:22,424 INFO [asr_datamodule.py:218] About to get test cuts + 2023-08-01 02:06:22,425 INFO [asr_datamodule.py:252] About to get test cuts + 2023-08-01 02:06:22,504 INFO [decode.py:204] batch 0/?, cuts processed until now is 4 + [W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware. + 2023-08-01 02:06:22,687 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt + 2023-08-01 02:06:22,688 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] + 2023-08-01 02:06:22,690 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt + 2023-08-01 02:06:22,690 INFO [decode.py:316] Done! + +Congratulations! You have finished successfully running `icefall`_ within a docker container. diff --git a/docs/source/index.rst b/docs/source/index.rst index a7d365a15..0fa8fdd1c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,9 +21,11 @@ speech recognition recipes using `k2 `_. 
:caption: Contents: installation/index + docker/index faqs model-export/index + .. toctree:: :maxdepth: 3 @@ -38,4 +40,4 @@ speech recognition recipes using `k2 `_. .. toctree:: :maxdepth: 2 - decoding-with-langugage-models/index \ No newline at end of file + decoding-with-langugage-models/index diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 534b674f9..5a034ef5b 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -3,6 +3,11 @@ Installation ============ +.. hint:: + + We also provide :ref:`icefall_docker` support, which has already setup + the environment for you. + .. hint:: We have a colab notebook guiding you step by step to setup the environment. From 1ee251c8b385f6dcf06da40b1760b76496b0d812 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:50:35 +0800 Subject: [PATCH 014/113] Decode zipformer with external LMs (#1193) * update some documentation * support decoding with LMs in zipformer recipe * update RESULTS.md --- .../decoding-with-langugage-models/LODR.rst | 54 ++--- .../rescoring.rst | 6 +- .../shallow-fusion.rst | 4 +- egs/librispeech/ASR/RESULTS.md | 7 + .../decode.py | 7 + egs/librispeech/ASR/zipformer/decode.py | 216 ++++++++++++++++-- 6 files changed, 238 insertions(+), 56 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst index 7ffa0c128..b6625ee1d 100644 --- a/docs/source/decoding-with-langugage-models/LODR.rst +++ b/docs/source/decoding-with-langugage-models/LODR.rst @@ -4,59 +4,59 @@ LODR for RNN Transducer ======================= -As a type of E2E model, neural transducers are usually considered as having an internal -language model, which learns the language level information on the training corpus. -In real-life scenario, there is often a mismatch between the training corpus and the target corpus space. +As a type of E2E model, neural transducers are usually considered as having an internal +language model, which learns the language level information on the training corpus. +In real-life scenario, there is often a mismatch between the training corpus and the target corpus space. This mismatch can be a problem when decoding for neural transducer models with language models as its internal language can act "against" the external LM. In this tutorial, we show how to use `Low-order Density Ratio `_ to alleviate this effect to further improve the performance -of langugae model integration. +of langugae model integration. .. note:: - This tutorial is based on the recipe + This tutorial is based on the recipe `pruned_transducer_stateless7_streaming `_, - which is a streaming transducer model trained on `LibriSpeech`_. + which is a streaming transducer model trained on `LibriSpeech`_. However, you can easily apply LODR to other recipes. If you encounter any problems, please open an issue here `icefall `__. .. note:: - For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However, - you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models + For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However, + you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models using that corpus. -First, let's have a look at some background information. 
As the predecessor of LODR, Density Ratio (DR) is first proposed `here `_ +First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here `_ to address the language information mismatch between the training corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain are acoustically similar, DR derives the following formular for decoding with Bayes' theorem: .. math:: - \text{score}\left(y_u|\mathit{x},y\right) = - \log p\left(y_u|\mathit{x},y_{1:u-1}\right) + - \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - + \text{score}\left(y_u|\mathit{x},y\right) = + \log p\left(y_u|\mathit{x},y_{1:u-1}\right) + + \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - \lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively. -Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to +where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively. +Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to shallow fusion is the subtraction of the source domain LM. -Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is +Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is considered to be weak and can only capture low-level language information. Therefore, `LODR `__ proposed to use a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula during decoding for transducer model: .. math:: - \text{score}\left(y_u|\mathit{x},y\right) = - \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) + - \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - + \text{score}\left(y_u|\mathit{x},y\right) = + \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) + + \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right) -In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR, +In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR, the only difference lies in the choice of source domain LM. According to the original `paper `_, LODR achieves similar performance compared DR in both intra-domain and cross-domain settings. As a bi-gram is much faster to evaluate, LODR is usually much faster. 
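The sketch below is only meant to make the scoring rule concrete; the actual implementation used during decoding (``modified_beam_search_LODR`` in ``beam_search.py``) is batched and more involved. The function and argument names here are illustrative, ``lm_scale`` and ``LODR_scale`` correspond to :math:`\lambda_1` and :math:`\lambda_2`, and in practice ``LODR_scale`` is given a negative value so that the bi-gram term is subtracted:

.. code-block:: python

    def lodr_score(
        rnnt_logp: float,  # log p_rnnt(y_u | x, y_{1:u-1}) from the transducer
        lm_logp: float,  # log p_{Target LM}(y_u | y_{1:u-1}) from the external LM
        bigram_logp: float,  # log p_{bi-gram}(y_u | y_{1:u-1}), the internal-LM proxy
        lm_scale: float = 0.42,  # example value for lambda_1; tune it on a dev set
        LODR_scale: float = -0.24,  # example value; the negative sign realizes "- lambda_2"
    ) -> float:
        """Combine the three per-token log-probabilities as in the LODR formula."""
        return rnnt_logp + lm_scale * lm_logp + LODR_scale * bigram_logp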
@@ -85,7 +85,7 @@ To test the model, let's have a look at the decoding results **without** using L --avg 1 \ --use-averaged-model False \ --exp-dir $exp_dir \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search @@ -99,17 +99,17 @@ The following WERs are achieved on test-clean and test-other: $ For test-other, WER of different settings are: $ beam_size_4 7.93 best for test-other -Then, we download the external language model and bi-gram LM that are necessary for LODR. +Then, we download the external language model and bi-gram LM that are necessary for LODR. Note that the bi-gram is estimated on the LibriSpeech 960 hours' text. .. code-block:: bash $ # download the external LM - $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm $ # create a symbolic link so that the checkpoint can be loaded $ pushd icefall-librispeech-rnn-lm/exp $ git lfs pull --include "pretrained.pt" - $ ln -s pretrained.pt epoch-99.pt + $ ln -s pretrained.pt epoch-99.pt $ popd $ $ # download the bi-gram @@ -122,7 +122,7 @@ Note that the bi-gram is estimated on the LibriSpeech 960 hours' text. Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``: .. code-block:: bash - + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ lm_dir=./icefall-librispeech-rnn-lm/exp $ lm_scale=0.42 @@ -135,8 +135,8 @@ Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_be --exp-dir $exp_dir \ --max-duration 600 \ --decode-chunk-len 32 \ - --decoding-method modified_beam_search_lm_LODR \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --decoding-method modified_beam_search_LODR \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 1 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ @@ -181,4 +181,4 @@ indeed **further improves** the WER. We can do even better if we increase ``--be - 6.38 * - 12 - 2.4 - - 6.23 \ No newline at end of file + - 6.23 diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst index ee2e2113c..02eba9129 100644 --- a/docs/source/decoding-with-langugage-models/rescoring.rst +++ b/docs/source/decoding-with-langugage-models/rescoring.rst @@ -48,7 +48,7 @@ As usual, we first test the model's performance without external LM. This can be --avg 1 \ --use-averaged-model False \ --exp-dir $exp_dir \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search @@ -101,7 +101,7 @@ is set to `False`. 
--max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search_lm_rescore \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 0 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ @@ -173,7 +173,7 @@ Then we can performn LM rescoring + LODR by changing the decoding method to `mod --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search_lm_rescore_LODR \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 0 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst index 0d2837372..f15e3f1d9 100644 --- a/docs/source/decoding-with-langugage-models/shallow-fusion.rst +++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst @@ -46,7 +46,7 @@ To test the model, let's have a look at the decoding results without using LM. T --avg 1 \ --use-averaged-model False \ --exp-dir $exp_dir \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search @@ -95,7 +95,7 @@ To use shallow fusion for decoding, we can execute the following command: --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search_lm_shallow_fusion \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 1 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 1b8e690bd..b945f43fd 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -90,6 +90,11 @@ You can use to deploy it. | greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 | | modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 | | fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 | +| modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 | +| modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 | +| modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 | +| modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 | + The training command is: ```bash @@ -119,6 +124,8 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html). 
+ ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 3444f8193..02029c108 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -396,6 +396,12 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -907,6 +913,7 @@ def main(): ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") logging.info(f"lm filename: {ngram_file_name}") ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search elif params.decoding_method == "modified_beam_search_LODR": lm_filename = f"{params.tokens_ngram}gram.fst.txt" diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 93680602e..2cc157e7a 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -115,9 +115,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -273,8 +278,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -302,6 +306,47 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + add_model_arguments(parser) return parser @@ -314,6 +359,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. 
The dict has the following format: @@ -342,6 +390,12 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -425,10 +479,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -445,6 +496,50 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) else: batch_size = encoder_out.size(0) @@ -483,6 +578,16 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"beam_size_{params.beam_size}_{key}"] = hyps + return ans else: return {f"beam_size_{params.beam_size}": hyps} @@ -494,6 +599,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
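# The sketch below is not part of this diff; it only illustrates how the new
# LM / ngram_lm / ngram_lm_scale arguments threaded through decode_one_batch()
# and decode_dataset() are expected to be built, mirroring the changes made to
# main() further below.  Concrete values and paths are examples only.
def build_external_lms(params, device):
    from icefall import LmScorer, NgramLm

    # Neural LM used for shallow fusion, LODR and rescoring.
    LM = LmScorer(
        lm_type=params.lm_type,  # "rnn" or "transformer"
        params=params,
        device=device,
        lm_scale=params.lm_scale,  # e.g. 0.3
    )
    LM.to(device)
    LM.eval()

    # Token-level n-gram LM, i.e. the low-order LM used by LODR.
    ngram_lm = NgramLm(
        str(params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"),
        backoff_id=params.backoff_id,
        is_binary=False,
    )
    return LM, ngram_lm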
@@ -543,6 +651,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -559,9 +670,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -594,8 +703,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -614,6 +722,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -628,6 +737,10 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -656,13 +769,19 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -690,9 +809,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -719,9 +838,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -768,6 +887,54 @@ def main(): model.to(device) model.eval() + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. 
You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -780,9 +947,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -811,6 +976,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) save_results( From 00256a766921dd34a267012b0e2b8ff7d538f0e6 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 9 Aug 2023 09:40:58 +0800 Subject: [PATCH 015/113] Fix decode_stream.py (#1208) * FIx decode_stream.py * Update decode_stream.py --- egs/librispeech/ASR/zipformer/decode_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decode_stream.py b/egs/librispeech/ASR/zipformer/decode_stream.py index 946db275c..d6918bf32 100644 --- a/egs/librispeech/ASR/zipformer/decode_stream.py +++ b/egs/librispeech/ASR/zipformer/decode_stream.py @@ -79,12 +79,12 @@ class DecodeStream(object): self.pad_length = 7 + 2 * 3 if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] elif params.decoding_method == "modified_beam_search": self.hyps = HypothesisList() self.hyps.add( Hypothesis( - ys=[params.blank_id] * params.context_size, + ys=[-1] * (params.context_size - 1) + [params.blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) From 74806b744b81620d06645c27f5a2dda307e58322 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 10 Aug 2023 20:56:02 +0800 Subject: [PATCH 016/113] disable speed perturbation by default (#1176) * disable speed perturbation by default * minor fixes * minor updates * updated bash scripts to incorporate with the `speed-perturb` arg * minor fixes 1. changed the naming scheme from `speed-perturb` to `perturb-speed` to align with the librispeech recipe >> https://github.com/k2-fsa/icefall/blob/00256a766921dd34a267012b0e2b8ff7d538f0e6/egs/librispeech/ASR/local/compute_fbank_librispeech.py#L65 2. 
changed arg type for `perturb-speed` to str2bool --- .../local/compute_fbank_aidatatang_200zh.py | 18 ++++++++--- egs/aidatatang_200zh/ASR/prepare.sh | 2 +- .../local/compute_fbank_aidatatang_200zh.py | 18 ++++++++--- .../ASR/local/compute_fbank_aishell.py | 18 ++++++++--- egs/aishell/ASR/prepare.sh | 2 +- egs/aishell/ASR/prepare_aidatatang_200zh.sh | 2 +- .../ASR/local/compute_fbank_aishell2.py | 17 +++++++--- egs/aishell2/ASR/prepare.sh | 2 +- .../ASR/local/compute_fbank_aishell4.py | 18 ++++++++--- egs/aishell4/ASR/prepare.sh | 2 +- .../ASR/local/compute_fbank_alimeeting.py | 17 +++++++--- egs/alimeeting/ASR/prepare.sh | 2 +- .../ASR_v2/local/compute_fbank_alimeeting.py | 32 ++++++++++++++++--- egs/alimeeting/ASR_v2/prepare.sh | 2 +- .../ASR/local/preprocess_wenetspeech.py | 20 ++++++++++-- egs/wenetspeech/ASR/prepare.sh | 2 +- 16 files changed, 132 insertions(+), 42 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index 387c14acf..9caacb78b 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): +def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests/aidatatang_200zh") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -109,7 +110,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -119,4 +125,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins) + compute_fbank_aidatatang_200zh( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 46ecd5769..2eb0b3718 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -77,7 +77,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for aidatatang_200zh" if [ ! 
-f data/fbank/.aidatatang_200zh.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py + ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True touch data/fbank/.aidatatang_200zh.done fi fi diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 037971927..6a9bb4f42 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): +def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -109,7 +110,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -119,4 +125,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins) + compute_fbank_aidatatang_200zh( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index 115ca1031..c7000da1c 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
@@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell(num_mel_bins: int = 80): +def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -81,7 +81,8 @@ def compute_fbank_aishell(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -104,7 +105,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -114,4 +120,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index b763d72c1..ff8e1301d 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -114,7 +114,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell" if [ ! -f data/fbank/.aishell.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell.py + ./local/compute_fbank_aishell.py --perturb-speed True touch data/fbank/.aishell.done fi fi diff --git a/egs/aishell/ASR/prepare_aidatatang_200zh.sh b/egs/aishell/ASR/prepare_aidatatang_200zh.sh index f1d4d18a7..ec89450df 100755 --- a/egs/aishell/ASR/prepare_aidatatang_200zh.sh +++ b/egs/aishell/ASR/prepare_aidatatang_200zh.sh @@ -53,7 +53,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Process aidatatang_200zh" if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py + ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True touch data/fbank/.aidatatang_200zh_fbank.done fi fi diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index ec0c584ca..1fb1621ff 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
@@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell2(num_mel_bins: int = 80): +def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -81,7 +81,8 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -104,6 +105,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -114,4 +121,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell2(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell2( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 3e8e840ab..42631c864 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell2" if [ ! -f data/fbank/.aishell2.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell2.py + ./local/compute_fbank_aishell2.py --perturb-speed True touch data/fbank/.aishell2.done fi fi diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 400c406f0..f19163988 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -32,7 +32,7 @@ import torch from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell4(num_mel_bins: int = 80): +def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests/aishell4") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -83,10 +83,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) + cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -113,6 +115,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", + ) return parser.parse_args() @@ -123,4 +131,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell4(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell4( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index cb2b73a3e..1b1ec0005 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -107,7 +107,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for aishell4" if [ ! -f data/fbank/.aishell4.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell4.py + ./local/compute_fbank_aishell4.py --perturb-speed True touch data/fbank/.aishell4.done fi fi diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index 96115a230..f8c10648a 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_alimeeting(num_mel_bins: int = 80): +def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests/alimeeting") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -82,7 +82,8 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -114,6 +115,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -124,4 +131,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins) + compute_fbank_alimeeting( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 604cc92c6..1709733c7 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -97,7 +97,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for alimeeting" if [ ! -f data/fbank/.alimeeting.done ]; then mkdir -p data/fbank - ./local/compute_fbank_alimeeting.py + ./local/compute_fbank_alimeeting.py --perturb-speed True touch data/fbank/.alimeeting.done fi fi diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py index c6aa2ab36..833d11c72 100755 --- a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py @@ -25,6 +25,7 @@ It looks for manifests in the directory data/manifests. 
The generated fbank features are saved in data/fbank. """ +import argparse import logging from pathlib import Path @@ -39,6 +40,8 @@ from lhotse.features.kaldifeat import ( ) from lhotse.recipes.utils import read_manifests_if_cached +from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -48,7 +51,7 @@ torch.set_num_interop_threads(1) torch.multiprocessing.set_sharing_strategy("file_system") -def compute_fbank_ami(): +def compute_fbank_ami(perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") @@ -84,8 +87,12 @@ def compute_fbank_ami(): suffix="jsonl.gz", ) - def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None: - cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) + def _extract_feats( + cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool + ) -> None: + if speed_perturb: + logging.info(f"Doing speed perturb") + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) _ = cuts.compute_and_store_features_batch( extractor=extractor, storage_path=storage_path, @@ -109,6 +116,7 @@ def compute_fbank_ami(): cuts_ihm, output_dir / "feats_train_ihm", src_dir / "cuts_train_ihm.jsonl.gz", + perturb_speed, ) logging.info("Processing train split IHM + reverberated IHM") @@ -117,6 +125,7 @@ def compute_fbank_ami(): cuts_ihm_rvb, output_dir / "feats_train_ihm_rvb", src_dir / "cuts_train_ihm_rvb.jsonl.gz", + perturb_speed, ) logging.info("Processing train split SDM") @@ -129,6 +138,7 @@ def compute_fbank_ami(): cuts_sdm, output_dir / "feats_train_sdm", src_dir / "cuts_train_sdm.jsonl.gz", + perturb_speed, ) logging.info("Processing train split GSS") @@ -141,6 +151,7 @@ def compute_fbank_ami(): cuts_gss, output_dir / "feats_train_gss", src_dir / "cuts_train_gss.jsonl.gz", + perturb_speed, ) logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") @@ -186,8 +197,21 @@ def compute_fbank_ami(): ) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", + ) + return parser.parse_args() + + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_ami() + args = get_args() + + compute_fbank_ami(perturb_speed=args.perturb_speed) diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh index 76a108771..1098840f8 100755 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -85,7 +85,7 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for alimeeting" mkdir -p data/fbank - python local/compute_fbank_alimeeting.py + python local/compute_fbank_alimeeting.py --perturb-speed True log "Combine features from train splits" lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ gzip -c > data/manifests/cuts_train_all.jsonl.gz diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py index 93ce750f8..5de3c23a9 100755 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import re from pathlib import Path @@ -24,6 +25,7 @@ from lhotse import CutSet, SupervisionSegment from lhotse.recipes.utils import read_manifests_if_cached from icefall import setup_logger +from icefall.utils import str2bool # Similar text filtering and normalization procedure as in: # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh @@ -45,7 +47,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_wenet_speech(): +def preprocess_wenet_speech(perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") output_dir.mkdir(exist_ok=True) @@ -110,7 +112,7 @@ def preprocess_wenet_speech(): ) # Run data augmentation that needs to be done in the # time domain. - if partition not in ["DEV", "TEST_NET", "TEST_MEETING"]: + if partition not in ["DEV", "TEST_NET", "TEST_MEETING"] and perturb_speed: logging.info( f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" @@ -120,10 +122,22 @@ def preprocess_wenet_speech(): cut_set.to_file(raw_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + def main(): setup_logger(log_filename="./log-preprocess-wenetspeech") - preprocess_wenet_speech() + args = get_args() + preprocess_wenet_speech(perturb_speed=args.perturb_speed) logging.info("Done") diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index f7b521794..097a59a5f 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -91,7 +91,7 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Preprocess WenetSpeech manifest" if [ ! 
-f data/fbank/.preprocess_complete ]; then - python3 ./local/preprocess_wenetspeech.py + python3 ./local/preprocess_wenetspeech.py --perturb-speed True touch data/fbank/.preprocess_complete fi fi From d6b28a11a70871a76b66ccf80667dd1d3ac1ab17 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 11 Aug 2023 23:57:00 +0800 Subject: [PATCH 017/113] Add export script for the yesno recipe. (#1212) --- .github/workflows/run-yesno-recipe.yml | 76 +++++++- egs/yesno/ASR/tdnn/decode.py | 1 - egs/yesno/ASR/tdnn/export.py | 118 ++++++++++++ egs/yesno/ASR/tdnn/export_onnx.py | 158 ++++++++++++++++ egs/yesno/ASR/tdnn/jit_pretrained.py | 199 ++++++++++++++++++++ egs/yesno/ASR/tdnn/onnx_pretrained.py | 241 +++++++++++++++++++++++++ egs/yesno/ASR/tdnn/pretrained.py | 37 +++- 7 files changed, 813 insertions(+), 17 deletions(-) create mode 100755 egs/yesno/ASR/tdnn/export.py create mode 100755 egs/yesno/ASR/tdnn/export_onnx.py create mode 100755 egs/yesno/ASR/tdnn/jit_pretrained.py create mode 100755 egs/yesno/ASR/tdnn/onnx_pretrained.py diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 8a2c94829..57f15fe87 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -44,11 +44,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -70,6 +65,7 @@ jobs: pip install --no-binary protobuf protobuf==3.20.* pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl + pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - name: Run yesno recipe shell: bash @@ -78,9 +74,75 @@ jobs: export PYTHONPATH=$PWD:$PYTHONPATH echo $PYTHONPATH - cd egs/yesno/ASR ./prepare.sh python3 ./tdnn/train.py python3 ./tdnn/decode.py - # TODO: Check that the WER is less than some value + + - name: Test exporting to pretrained.pt + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 + + python3 ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Test exporting to torchscript + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Test exporting to onnx + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 + + echo "Test float32 model" + python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + + 
echo "Test int8 model" + python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Show generated files + shell: bash + working-directory: ${{github.workspace}} + run: | + cd egs/yesno/ASR + ls -lh tdnn/exp diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index d5efb41df..f520607af 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -65,7 +65,6 @@ def get_params() -> AttributeDict: { "exp_dir": Path("tdnn/exp/"), "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), "feature_dim": 23, "search_beam": 20, "output_beam": 8, diff --git a/egs/yesno/ASR/tdnn/export.py b/egs/yesno/ASR/tdnn/export.py new file mode 100755 index 000000000..c40cf8cd1 --- /dev/null +++ b/egs/yesno/ASR/tdnn/export.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +""" +This file is for exporting trained models to a checkpoint +or to a torchscript model. + +(1) Generate the checkpoint tdnn/exp/pretrained.pt + +./tdnn/export.py \ + --epoch 14 \ + --avg 2 + +See ./tdnn/pretrained.py for how to use the generated file. + +(2) Generate torchscript model tdnn/exp/cpu_jit.pt + +./tdnn/export.py \ + --epoch 14 \ + --avg 2 \ + --jit 1 + +See ./tdnn/jit_pretrained.py for how to use the generated file. +""" + +import argparse +import logging + +import torch +from model import Tdnn +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=14, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=2, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. 
+ """, + ) + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + + model = Tdnn( + num_features=params.feature_dim, + num_classes=max_token_id + 1, # +1 for the blank symbol + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/export_onnx.py b/egs/yesno/ASR/tdnn/export_onnx.py new file mode 100755 index 000000000..9b2a56d59 --- /dev/null +++ b/egs/yesno/ASR/tdnn/export_onnx.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +""" +This file is for exporting trained models to onnx. + +Usage: + + ./tdnn/export_onnx.py \ + --epoch 14 \ + --avg 2 + +The above command generates the following two files: + - ./exp/model-epoch-14-avg-2.onnx + - ./exp/model-epoch-14-avg-2.int8.onnx + +See ./tdnn/onnx_pretrained.py for how to use them. +""" + +import argparse +import logging +from typing import Dict + +import onnx +import torch +from model import Tdnn +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=14, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=2, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + + model = Tdnn( + num_features=params.feature_dim, + num_classes=max_token_id + 1, # +1 for the blank symbol + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + N = 1 + T = 100 + C = params.feature_dim + x = torch.rand(N, T, C) + + opset_version = 13 + onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx" + torch.onnx.export( + model, + x, + onnx_filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["log_prob"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "log_prob": {0: "N", 1: "T"}, + }, + ) + + logging.info(f"Saved to {onnx_filename}") + meta_data = { + "model_type": "tdnn_lstm", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming tdnn for the yesno recipe", + "vocab_size": max_token_id + 1, + } + + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=onnx_filename, meta_data=meta_data) + + logging.info("Generate int8 quantization models") + onnx_filename_int8 = ( + f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx" + ) + + quantize_dynamic( + model_input=onnx_filename, + model_output=onnx_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + logging.info(f"Saved to {onnx_filename_int8}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py new file mode 100755 index 000000000..84390fca5 --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a torchscript model for decoding. + +Usage: + + ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +from typing import List +import math + + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import get_lattice, one_best_decoding +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. 
+ You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "num_classes": 4, # [, N, SIL, Y] + "sample_rate": 8000, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in 
zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py new file mode 100755 index 000000000..626473b6e --- /dev/null +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use an ONNX model for decoding with onnxruntime. + +Usage: + +(1) Use a not quantized ONNX model, i.e., a float32 model + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +(2) Use a quantized ONNX model, i.e., an int8 model + + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx, +and ./tdnn/exp/model-epoch-14-avg-2.onnx, +you can use ./export_onnx.py --epoch 14 --avg 2 +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import get_lattice, one_best_decoding +from icefall.utils import AttributeDict, get_texts + + +class OnnxModel: + def __init__(self, nn_model: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + ) + + meta = self.model.get_modelmeta().custom_metadata_map + self.vocab_size = int(meta["vocab_size"]) + + def run( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor log_prob of shape (N, T, C) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. 
+ Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "sample_rate": 8000, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + logging.info(f"Loading onnx model {params.nn_model}") + model = OnnxModel(params.nn_model) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output = model.run(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 65be77db1..987c49de6 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -15,6 +15,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This file shows how to use a checkpoint for decoding. 
+ +Usage: + + ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/pretrained.pt, +you can use ./export.py +""" import argparse import logging @@ -43,7 +58,8 @@ def get_parser(): required=True, help="Path to the checkpoint. " "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + "icefall.checkpoint.save_checkpoint(). " + "You can use ./tdnn/export.py to obtain it.", ) parser.add_argument( @@ -61,8 +77,7 @@ def get_parser(): nargs="+", help="The input sound file(s) to transcribe. " "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + "For example, wav and flac are supported. ", ) return parser @@ -99,14 +114,19 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -159,8 +179,7 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output = model(features) + nnet_output = model(features) batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( From a81396b482c799b2ace2cefb79859be827b16f00 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 12 Aug 2023 16:53:59 +0800 Subject: [PATCH 018/113] Use tokens.txt to replace bpe.model (#1162) --- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 10 +- ...h-lstm-transducer-stateless2-2022-09-03.sh | 6 +- ...-pruned-transducer-stateless-2022-03-12.sh | 4 +- ...pruned-transducer-stateless2-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-05-13.sh | 8 +- ...pruned-transducer-stateless5-2022-05-13.sh | 4 +- ...pruned-transducer-stateless7-2022-11-11.sh | 6 +- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 6 +- ...transducer-stateless7-ctc-bs-2023-01-29.sh | 6 +- ...nsducer-stateless7-streaming-2022-12-29.sh | 6 +- ...pruned-transducer-stateless8-2022-11-14.sh | 6 +- ...pruned-transducer-stateless2-2022-06-26.sh | 4 +- ...speech-transducer-stateless2-2022-04-19.sh | 4 +- ...un-librispeech-zipformer-mmi-2022-12-08.sh | 4 +- .../scripts/run-pre-trained-conformer-ctc.sh | 4 +- ...d-transducer-stateless-librispeech-100h.sh | 4 +- ...d-transducer-stateless-librispeech-960h.sh | 4 +- .../run-pre-trained-transducer-stateless.sh | 4 +- .github/scripts/run-pre-trained-transducer.sh | 2 +- ...enetspeech-pruned-transducer-stateless2.sh | 36 +- .github/scripts/test-ncnn-export.sh | 12 +- .github/scripts/test-onnx-export.sh | 138 ++++++- .../pruned_transducer_stateless7/export.py | 322 +--------------- .../pretrained.py | 349 +----------------- egs/librispeech/ASR/conformer_ctc/export.py | 18 +- .../ASR/conformer_ctc/pretrained.py | 40 +- egs/librispeech/ASR/conformer_ctc2/export.py | 19 +- 
egs/librispeech/ASR/conformer_ctc3/export.py | 23 +- .../ASR/conformer_ctc3/pretrained.py | 42 ++- .../export.py | 22 +- .../export-for-ncnn.py | 22 +- .../export-onnx.py | 25 +- .../export.py | 22 +- .../onnx_pretrained.py | 2 +- .../ASR/lstm_transducer_stateless/export.py | 25 +- .../lstm_transducer_stateless/pretrained.py | 49 +-- .../export-for-ncnn.py | 23 +- .../export-onnx-zh.py | 2 +- .../lstm_transducer_stateless2/export-onnx.py | 25 +- .../ASR/lstm_transducer_stateless2/export.py | 25 +- .../lstm_transducer_stateless2/pretrained.py | 49 +-- .../ASR/lstm_transducer_stateless3/export.py | 25 +- .../lstm_transducer_stateless3/pretrained.py | 46 ++- .../pruned_stateless_emformer_rnnt2/export.py | 23 +- .../export-onnx.py | 2 +- .../ASR/pruned_transducer_stateless/export.py | 24 +- .../pruned_transducer_stateless/pretrained.py | 49 +-- .../pruned_transducer_stateless2/export.py | 22 +- .../pretrained.py | 49 +-- .../export-onnx.py | 24 +- .../pruned_transducer_stateless3/export.py | 26 +- .../pretrained.py | 51 +-- .../pruned_transducer_stateless4/export.py | 22 +- .../export-onnx-streaming.py | 26 +- .../export-onnx.py | 26 +- .../pruned_transducer_stateless5/export.py | 22 +- .../pretrained.py | 49 +-- .../pruned_transducer_stateless6/export.py | 22 +- .../export-onnx.py | 27 +- .../pruned_transducer_stateless7/export.py | 30 +- .../pretrained.py | 55 +-- .../export.py | 24 +- .../pretrained.py | 51 +-- .../pretrained_ctc.py | 10 +- .../export.py | 24 +- .../export_onnx.py | 26 +- .../pretrained.py | 51 +-- .../pretrained_ctc.py | 10 +- .../export-for-ncnn-zh.py | 21 +- .../export-for-ncnn.py | 22 +- .../export-onnx-zh.py | 25 +- .../export-onnx.py | 24 +- .../export.py | 20 +- .../pretrained.py | 51 +-- .../export-for-ncnn.py | 22 +- .../pruned_transducer_stateless8/export.py | 24 +- .../pretrained.py | 51 +-- egs/librispeech/ASR/transducer/export.py | 22 +- egs/librispeech/ASR/transducer/pretrained.py | 33 +- .../ASR/transducer_stateless/export.py | 22 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless2/export.py | 22 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../export.py | 22 +- .../pretrained.py | 36 +- .../ASR/zipformer/export-onnx-streaming.py | 4 +- egs/librispeech/ASR/zipformer/export-onnx.py | 4 +- egs/librispeech/ASR/zipformer/export.py | 25 +- .../ASR/zipformer/jit_pretrained_ctc.py | 18 +- egs/librispeech/ASR/zipformer/onnx_check.py | 1 - .../zipformer/onnx_pretrained-streaming.py | 3 +- .../ASR/zipformer/onnx_pretrained.py | 1 - .../ASR/zipformer/pretrained_ctc.py | 20 +- egs/librispeech/ASR/zipformer_mmi/export.py | 24 +- .../ASR/zipformer_mmi/pretrained.py | 47 +-- .../export-onnx.py | 2 +- .../pretrained.py | 2 +- icefall/utils.py | 20 + 99 files changed, 1243 insertions(+), 1623 deletions(-) mode change 100755 => 120000 egs/aishell/ASR/pruned_transducer_stateless7/export.py mode change 100644 => 120000 egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh index c68ccc954..f6fe8c9b2 100755 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()" for m in ctc-decoding 1best; do ./conformer_ctc3/jit_pretrained.py \ --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file 
$repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --G $repo/data/lm/G_4_gram.pt \ @@ -53,7 +53,7 @@ log "Export to torchscript model" ./conformer_ctc3/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --jit-trace 1 \ --epoch 99 \ --avg 1 \ @@ -80,9 +80,9 @@ done for m in ctc-decoding 1best; do ./conformer_ctc3/pretrained.py \ --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --G $repo/data/lm/G_4_gram.pt \ --method $m \ --sample-rate 16000 \ @@ -93,7 +93,7 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p conformer_ctc3/exp ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index 4cd2c4bec..d547bdd45 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,7 +55,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index 6792c7088..412e3ad56 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens 
$repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index dbf678d72..243b669ed 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -36,7 +36,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index b6d477afe..2d0f80304 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -35,7 +35,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index efa4b53f0..3d5814c48 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -30,14 +30,14 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do 
--method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index 511fe0c9e..3d2442d54 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 18 \ --dim-feedforward 2048 \ --nhead 8 \ @@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 2bc179c86..961dde4f4 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -33,7 +33,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -56,7 +56,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index 192438353..ba7139efb 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ 
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh index 7d2853c17..1ecbc4798 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh @@ -36,7 +36,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -72,7 +72,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index e1e4e1f10..37b192a57 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ --epoch 99 \ --avg 1 \ @@ -81,7 +81,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ @@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh index 5d9485692..4f2bfac24 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ 
b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()" log "Export to torchscript model" ./pruned_transducer_stateless8/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model false \ --epoch 99 \ --avg 1 \ @@ -65,7 +65,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index 77cd59506..5cbdad16d 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ @@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index b4aca1b6b..ff77855a2 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh index a58b8ec56..c59921055 100755 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -37,7 +37,7 @@ log 
"Export to torchscript model" ./zipformer_mmi/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor --method $method \ --checkpoint $repo/exp/pretrained.pt \ --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 125d1f3b1..a4959aa01 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -27,7 +27,7 @@ log "CTC decoding" --method ctc-decoding \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.flac \ $repo/test_wavs/1221-135766-0001.flac \ $repo/test_wavs/1221-135766-0002.flac @@ -38,7 +38,7 @@ log "HLG decoding" --method 1best \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ $repo/test_wavs/1089-134686-0001.flac \ diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 89115e88d..7b686328d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 85e2c89e6..a8eeeb514 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + 
--tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 41456f11b..2e2360435 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 1331c966c..b865f8d13 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -27,7 +27,7 @@ log "Beam search decoding" --method beam_search \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index 90097c752..a3a2d3080 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -17,7 +17,6 @@ git lfs install git clone $repo_url repo=$(basename $repo_url) - log "Display test files" tree $repo/ ls -lh $repo/test_wavs/*.wav @@ -29,12 +28,11 @@ popd log "Test exporting to ONNX format" -./pruned_transducer_stateless2/export.py \ +./pruned_transducer_stateless2/export-onnx.py \ --exp-dir $repo/exp \ --lang-dir $repo/data/lang_char \ --epoch 99 \ - --avg 1 \ - --onnx 1 + --avg 1 log "Export to torchscript model" @@ -59,19 +57,17 @@ log "Decode with ONNX models" ./pruned_transducer_stateless2/onnx_check.py \ --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + --onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx ./pruned_transducer_stateless2/onnx_pretrained.py \ --tokens $repo/data/lang_char/tokens.txt \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx 
\ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav @@ -104,9 +100,9 @@ for sym in 1 2 3; do --lang-dir $repo/data/lang_char \ --decoding-method greedy_search \ --max-sym-per-frame $sym \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done for method in modified_beam_search beam_search fast_beam_search; do @@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --beam-size 4 \ --checkpoint $repo/exp/epoch-99.pt \ --lang-dir $repo/data/lang_char \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index ac16131d0..4073c594a 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -56,11 +55,10 @@ log "Export via torch.jit.trace()" ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ - \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ @@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" cd exp @@ -102,7 +99,7 @@ log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 @@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 9999 \ diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 39467c44a..fcfc11fa6 100755 --- a/.github/scripts/test-onnx-export.sh +++ 
b/.github/scripts/test-onnx-export.sh @@ -10,7 +10,123 @@ log() { cd egs/librispeech/ASR +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" +./zipformer/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +log "Test export to ONNX format" +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./zipformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Test export streaming model to ONNX format" +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +ls -lh $repo/exp + +log "Run onnx_pretrained-streaming.py" + +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + 
$repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +log "--------------------------------------------------------------------------" log "==========================================================================" repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 @@ -39,7 +155,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -88,7 +204,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless3/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ \ @@ -97,7 +213,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -126,7 +242,6 @@ log "Run onnx_pretrained.py" rm -rf $repo log "--------------------------------------------------------------------------" - log "==========================================================================" repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -143,7 +258,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless5/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -159,7 +274,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -215,7 +329,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless7/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -226,7 +340,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -270,7 +384,7 @@ popd log "Test exporting to ONNX format" ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -310,7 +424,7 @@ popd log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -320,7 +434,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ 
--use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py deleted file mode 100755 index 1b0e8d3b9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. 
You can get it at - -https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - # You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. 
Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100644 index cc54027d6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. 
- -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - token_table = lexicon.token_table - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp_tokens = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp_tokens = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] 
%(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py new file mode 120000 index 000000000..068f0f57f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index fbcbd7b29..f0bb97560 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -23,12 +23,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +64,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -98,16 +98,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 30def9c40..df3e4d819 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -24,7 +24,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -70,11 +69,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -83,10 +80,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. 
@@ -297,6 +293,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) @@ -313,10 +310,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=params.num_classes > 500, @@ -337,9 +341,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "whole-lattice-rescoring", @@ -408,16 +412,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 7892b03c6..26a95dbfa 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -23,6 +23,7 @@ Usage: ./conformer_ctc2/export.py \ --exp-dir ./conformer_ctc2/exp \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from decode import get_params @@ -56,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +124,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -143,14 +144,14 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py index c5b95d981..5cb9b4b6d 100755 --- a/egs/librispeech/ASR/conformer_ctc3/export.py +++ 
b/egs/librispeech/ASR/conformer_ctc3/export.py @@ -25,7 +25,7 @@ Usage: ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`. ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -62,6 +62,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -72,8 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -130,10 +130,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -171,9 +171,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank params.vocab_size = num_classes if params.streaming_model: diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index 880945ea0..c37b99cce 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -24,7 +24,7 @@ Usage (for non-streaming mode): (1) ctc-decoding ./conformer_ctc3/pretrained.py \ --checkpoint conformer_ctc3/exp/pretrained.pt \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method ctc-decoding \ --sample-rate 16000 \ test_wavs/1089-134686-0001.wav @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -114,11 +113,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -127,10 +124,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. 
@@ -316,6 +312,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) feature_lengths = [f.size(0) for f in features] @@ -348,10 +345,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=False, @@ -372,9 +376,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "nbest-rescoring", @@ -439,16 +443,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 09a3e96b0..67fcc35a4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless/export.py \ --exp-dir ./conv_emformer_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -72,7 +72,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,10 +118,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,12 +166,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 8fbb02f14..85dbd4661 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -8,7 +8,7 @@ for more details about how to use this file. Usage: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -37,7 +37,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -48,7 +48,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -94,10 +94,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -217,12 +217,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() + # Load id of the <blk> token and the vocab size + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index ad0b45bd9..cfd365207 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -28,7 +27,7 @@ popd 2.
Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -55,14 +54,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model -from emformer import Emformer from icefall.checkpoint import ( average_checkpoints, @@ -70,7 +69,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -127,10 +126,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -484,12 +483,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() + # Load id of the <blk> token and the vocab size + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index b53426c75..8e5b14903 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless2/export.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -73,7 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -119,10 +119,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -167,12 +167,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() + # Load id of the <blk> token and the vocab size + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index db92ac696..5d7e2dfcd 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index e338342cc..c007220d5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index b3a34a9e3..119fcf1fd 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint 
./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py index 08bfcb204..2b8c92208 100755 --- 
a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -29,7 +29,7 @@ popd ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -49,7 +49,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -60,7 +60,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -221,12 +221,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() + # Load id of the <blk> token and the vocab size, <blk> is + # defined in local/train_bpe_model.py + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index f068f6a0f..89ced388c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -613,7 +613,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index acaff8540..6b6cb893f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2.
Export the model to ONNX ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -52,8 +52,8 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -68,7 +68,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -125,10 +125,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -437,12 +437,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -607,7 +608,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 0adc68112..5712da25e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -27,7 +27,7 @@ Usage: ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -80,7 +80,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -92,7 +92,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -149,10 +149,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -267,12 +267,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined 
in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index f3f272b9f..5d6d97320 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -69,7 +69,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -82,6 +81,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -98,9 +99,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -217,13 +218,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -278,6 +280,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -289,8 +297,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ 
-299,16 +307,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -329,12 +337,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index a82cad043..21eaa049b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index f49e9c518..29a0d4d1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search 
./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -79,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +216,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +278,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +295,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +305,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +335,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 3612a2bfd..ec2c9d580 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -22,7 +22,7 @@ Usage: 
./prunted_stateless_emformer_rnnt/export.py \ --exp-dir ./prunted_stateless_emformer_rnnt/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -58,7 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -154,13 +154,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py index a3ebe9d8c..282238c13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -508,7 +508,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index a19f9ab9a..4b20e3a2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless/export.py \ --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -87,10 +87,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,13 +135,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in 
local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 2ed1725b4..02f9f1b03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -237,13 +236,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -314,6 +314,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -325,8 +331,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == 
"modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -335,16 +341,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -365,12 +371,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 984caf5f2..e02afa892 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -145,12 +145,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 013964720..029f55ba0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + 
--tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -238,13 +237,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -315,6 +315,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -326,8 +332,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -336,16 +342,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -366,12 +372,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff 
--git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 9645b7801..26dea7e11 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import setup_logger +from icefall.utils import num_tokens, setup_logger def get_parser(): @@ -105,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -393,12 +393,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -518,7 +520,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index f30c9df6a..925b15646 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit 1 @@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -97,14 +97,14 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import 
average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -150,10 +150,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -342,12 +342,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 7c3dfc660..abda4e2d4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(vars(args)) - 
sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -324,6 +324,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -335,8 +341,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -345,16 +351,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -375,12 +381,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 8f33f5b05..08d736f52 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + 
token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 938ff2f16..549fb13c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +131,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -489,12 +489,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -662,7 +664,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 20fd8dff8..fff0fcdd5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,13 +55,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -71,7 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -416,12 +416,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -586,7 +588,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index 54f656859..e5223be26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - 
params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..304fa8693 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + 
hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 4d0d8326c..38f48b2ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless6/export.py \ --exp-dir ./pruned_transducer_stateless6/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +135,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index d2db92820..11c885f4d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) """ This script exports a transducer model from PyTorch to ONNX. @@ -18,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" cd exp @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -50,8 +50,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -66,7 +66,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -411,12 +410,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -581,7 +580,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..eb4c4d282 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -26,7 +27,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +46,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -65,7 +66,7 @@ you can do: --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --tokens data/lang_bpe_500/tokens.txt \ Check ./pretrained.py for its usage. 
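# --- Illustrative sketch (not part of the patch) ----------------------------
# The hunks in this patch repeatedly replace the sentencepiece-based setup
# (sp.piece_to_id / sp.get_piece_size) with a k2 symbol-table based one.  The
# standalone sketch below shows the resulting pattern in one place.  It is a
# minimal example, assuming k2 and icefall are installed and that tokens.txt
# is the symbol table produced by the recipe's data preparation, mapping
# tokens such as "<blk>" and "<unk>" to integer IDs.
import k2

from icefall.utils import num_tokens


def load_token_table(tokens_file: str):
    """Load tokens.txt and derive the IDs the transducer scripts need."""
    token_table = k2.SymbolTable.from_file(tokens_file)
    blank_id = token_table["<blk>"]  # blank symbol ID
    unk_id = token_table["<unk>"]    # out-of-vocabulary symbol ID
    # num_tokens() counts the non-disambiguation tokens; the "+ 1" used
    # throughout this patch accounts for the blank symbol.
    vocab_size = num_tokens(token_table) + 1
    return token_table, blank_id, unk_id, vocab_size


# Example usage (the path is a placeholder):
#   table, blank_id, unk_id, vocab_size = load_token_table(
#       "data/lang_bpe_500/tokens.txt"
#   )
# -----------------------------------------------------------------------------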
@@ -86,7 +87,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +156,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -292,7 +292,7 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index d05bafcfb..86c922cda 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,7 +21,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +30,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +38,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +47,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +56,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +76,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +88,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens def get_parser(): @@ -106,9 +106,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -225,13 +225,13 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = 
modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py index c1607699f..51e62d6a8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 2f1b1a49f..78e0fa778 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model 
data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: 
hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 5d460edb5..904c1deae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 05df8cfff..9f35cf63e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
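# A minimal, self-contained sketch (not part of the patch): the recurring change
# in the export scripts above replaces the SentencePiece bpe.model with a plain
# tokens.txt lookup.  Assuming tokens.txt uses the usual "SYMBOL ID" pair per
# line written by icefall's data preparation, this shows where blank_id, unk_id
# and vocab_size now come from.  The helper name and paths are illustrative; the
# real scripts use k2.SymbolTable.from_file() plus icefall.utils.num_tokens().
def load_token_table(filename: str) -> dict:
    """Parse a tokens.txt file into a {symbol: integer id} dict."""
    table = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            sym, idx = line.split()
            table[sym] = int(idx)
    return table


token_table = load_token_table("data/lang_bpe_500/tokens.txt")
blank_id = token_table["<blk>"]  # usually 0
unk_id = token_table["<unk>"]
# icefall.utils.num_tokens() also handles details such as disambiguation
# symbols ("#0", ...); this simplified count assumes a plain BPE tokens.txt.
vocab_size = max(token_table.values()) + 1  # +1 because ids start at 0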
./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py index 630a7f735..d3033b888 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -28,7 +28,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx 1 @@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them. 
(2) Export to ONNX format which can be used in Triton Server ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx-triton 1 @@ -86,9 +86,10 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool -from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +156,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -728,12 +728,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index ea0fe9164..5d240cf30 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - 
--bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 412631ba1..914107526 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens 
data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index e196f8b7d..07de57a86 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -66,6 +66,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -246,9 +246,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the <blk> token and the vocab size + # <blk> is defined in local/train_bpe_model.py + params.blank_id = token_table["<blk>"] + params.unk_id = token_table["<unk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. 
Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 04d97808d..8653126de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -29,7 +29,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -60,6 +60,7 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -134,10 +134,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -493,9 +493,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -661,7 +666,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index e71bcaf29..6f84d79b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -27,7 +27,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -65,7 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -122,10 +122,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -481,12 +481,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -652,7 +654,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index c191b5bcc..59a7eb589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -139,8 +139,8 @@ import argparse import logging from pathlib import Path +import k2 import onnxruntime -import sentencepiece as spm import torch import torch.nn as nn from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner @@ -154,7 +154,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -211,10 +211,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -675,12 +675,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index fb77fdd42..bc42e8d05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py 
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif 
params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index d4a228b47..d9697680b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +98,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +155,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 486d9d74e..64b38c9d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import 
add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 6db0272f0..3b9e4a5dc 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -22,7 +22,7 @@ Usage: ./transducer/export.py \ --exp-dir ./transducer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 34 \ --avg 11 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from conformer import Conformer from decoder import Decoder @@ -55,7 +55,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -90,10 +90,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - 
default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 511610245..c2413f5de 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -19,7 +19,7 @@ Usage: ./transducer/pretrained.py \ --checkpoint ./transducer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav \ @@ -36,8 +36,8 @@ import logging import math from typing import List +import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -48,7 +48,7 @@ from model import Transducer from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, num_tokens def get_parser(): @@ -66,11 +66,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. 
- """, + help="Path to tokens.txt.", ) parser.add_argument( @@ -204,12 +202,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -257,6 +257,12 @@ def main(): x=features, x_lens=feature_lengths ) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + num_waves = encoder_out.size(0) hyps = [] for i in range(num_waves): @@ -272,12 +278,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 89359f1a4..c397eb171 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -91,10 +91,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 915a6069d..5898dd0f5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ 
/path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index d33d02642..f4b6f5554 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless2/export.py \ --exp-dir ./transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,12 +46,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from 
icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -86,10 +86,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -123,12 +123,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 0738f30c0..b69b347ef 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" 
logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index 3735ef452..6d31dfe34 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless_multi_datasets/export.py \ --exp-dir ./transducer_stateless_multi_datasets/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,7 +47,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -57,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -192,12 +192,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 8c7726367..4f29d6f1f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ 
-38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 3eb06f68c..a951aeef3 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -74,7 +73,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -86,7 +84,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def 
get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 724fdd2a6..e0d664009 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -71,7 +70,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -83,7 +81,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index 4a48d5bad..2b8d1aaf3 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -160,7 +160,6 @@ with the following commands: import argparse import logging -import re from pathlib import Path from typing import List, Tuple @@ -176,27 +175,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool - - -def num_tokens( - token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") -) -> int: - """Return the number of tokens excluding those from - disambiguation symbols. - - Caution: - 0 is not a token ID so it is excluded from the return value. - """ - symbols = token_table.symbols - ans = [] - for s in symbols: - if not disambig_pattern.match(s): - ans.append(token_table[s]) - num_tokens = len(ans) - if 0 in ans: - num_tokens -= 1 - return num_tokens +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): @@ -487,6 +466,8 @@ def main(): device=device, ) ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: assert params.avg > 0, params.avg start = params.epoch - params.avg diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 904d8cd76..660a4bfc6 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -410,10 +410,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/onnx_check.py b/egs/librispeech/ASR/zipformer/onnx_check.py index b38b875d0..93bd3a211 100755 --- a/egs/librispeech/ASR/zipformer/onnx_check.py +++ b/egs/librispeech/ASR/zipformer/onnx_check.py @@ -33,7 +33,6 @@ 
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 2ce4506a8..500b2cd09 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -29,7 +28,7 @@ popd 2. Export the model to ONNX ./zipformer/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index e8a521460..032b07721 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index be239e9c3..9dff2e6fc 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -274,7 +274,7 @@ def main(): params.update(vars(args)) token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + params.vocab_size = num_tokens(token_table) + 1 # +1 for blank params.blank_id = token_table[""] assert params.blank_id == 0 @@ -429,10 +429,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py index 0af7bd367..1aec56420 100755 --- a/egs/librispeech/ASR/zipformer_mmi/export.py +++ b/egs/librispeech/ASR/zipformer_mmi/export.py @@ -26,7 +26,7 @@ Usage: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -190,12 +190,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 0e7fd0daf..3ba4da5dd 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -21,7 +21,7 @@ You can generate the checkpoint with the following command: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -30,14 +30,14 @@ Usage of this script: (1) 1best ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method 1best \ /path/to/foo.wav \ /path/to/bar.wav (2) nbest ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest \ /path/to/foo.wav \ @@ -45,7 +45,7 @@ Usage of this script: (3) nbest-rescoring-LG ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-LG \ /path/to/foo.wav \ @@ -53,7 +53,7 @@ Usage of this script: (4) nbest-rescoring-3-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-3-gram \ /path/to/foo.wav \ @@ -61,7 +61,7 @@ Usage of this script: (5) nbest-rescoring-4-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-4-gram \ /path/to/foo.wav \ @@ -83,7 +83,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -97,7 +96,7 @@ from icefall.decode import 
( one_best_decoding, ) from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import get_texts +from icefall.utils import get_texts, num_tokens def get_parser(): @@ -115,9 +114,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -298,8 +298,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) mmi_graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, uniq_filename="lexicon.txt", @@ -313,6 +311,12 @@ def main(): if not hasattr(HP, "lm_scores"): HP.lm_scores = HP.scores.clone() + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + method = params.method assert method in ( "1best", @@ -390,14 +394,11 @@ def main(): # # token_ids is a lit-of-list of IDs token_ids = get_texts(best_path) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] - hyps = [s.split() for s in hyps] + hyps = [token_ids_to_words(ids) for ids in token_ids] + s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index fad66986b..760fad974 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -498,7 +498,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index bc499f3dd..c3d67ad92 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -320,7 +320,7 @@ def main(): s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) + words = "".join(hyp) s += f"{filename}:\n{words}\n\n" logging.info(s) diff --git a/icefall/utils.py b/icefall/utils.py index 0feff9dc8..b01cd2770 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2060,3 +2060,23 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str): except OSError: copyfile(src=exp_dir / src, dst=exp_dir / dst) os.close(dir_fd) + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. 
+ """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens From dfccadc6b6551696e2dfff787f3ec102e346d4cd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Aug 2023 16:59:06 +0800 Subject: [PATCH 019/113] Fix a typo in export_onnx.py for yesno (#1213) --- egs/yesno/ASR/tdnn/export_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/yesno/ASR/tdnn/export_onnx.py b/egs/yesno/ASR/tdnn/export_onnx.py index 9b2a56d59..2436ca81b 100755 --- a/egs/yesno/ASR/tdnn/export_onnx.py +++ b/egs/yesno/ASR/tdnn/export_onnx.py @@ -126,7 +126,7 @@ def main(): logging.info(f"Saved to {onnx_filename}") meta_data = { - "model_type": "tdnn_lstm", + "model_type": "tdnn", "version": "1", "model_author": "k2-fsa", "comment": "non-streaming tdnn for the yesno recipe", From b0e8a40c8932d82d356b8a2ad4948331eae9749e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 12 Aug 2023 21:50:59 -0400 Subject: [PATCH 020/113] Speed up yesno training to finish in ~10s on CPU (#1215) --- egs/yesno/ASR/tdnn/asr_datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index 3c1682fa1..ada8c1a6c 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -209,7 +209,7 @@ class YesNoAsrDataModule(DataModule): sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, - persistent_workers=False, + persistent_workers=True, ) return train_dl @@ -236,6 +236,7 @@ class YesNoAsrDataModule(DataModule): batch_size=None, sampler=sampler, num_workers=self.args.num_workers, + persistent_workers=True, ) return test_dl From 3b5645f5944393121e52739d5b9d5ef43a7e7a0f Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 13 Aug 2023 12:37:08 +0800 Subject: [PATCH 021/113] doc updated (#1214) --- docs/source/model-export/export-model-state-dict.rst | 4 ++-- docs/source/model-export/export-ncnn-conv-emformer.rst | 3 +-- docs/source/model-export/export-ncnn-lstm.rst | 2 +- docs/source/model-export/export-ncnn-zipformer.rst | 3 +-- docs/source/model-export/export-onnx.rst | 2 +- docs/source/model-export/export-with-torch-jit-script.rst | 2 +- docs/source/model-export/export-with-torch-jit-trace.rst | 2 +- 7 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docs/source/model-export/export-model-state-dict.rst b/docs/source/model-export/export-model-state-dict.rst index c3bbd5708..5596bb7a6 100644 --- a/docs/source/model-export/export-model-state-dict.rst +++ b/docs/source/model-export/export-model-state-dict.rst @@ -41,7 +41,7 @@ as an example. 
./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -78,7 +78,7 @@ In each recipe, there is also a file ``pretrained.py``, which can use ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \ + --tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ --method greedy_search \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \ diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 12b370143..4f5535d83 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -153,11 +153,10 @@ Next, we use the following code to export our model: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir $dir/exp \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --tokens $dir/data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 1 \ --use-averaged-model 0 \ - \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst index 8e6dc7466..310c3d8e4 100644 --- a/docs/source/model-export/export-ncnn-lstm.rst +++ b/docs/source/model-export/export-ncnn-lstm.rst @@ -73,7 +73,7 @@ Next, we use the following code to export our model: ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $dir/exp \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --tokens $dir/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst index 8440d26b7..a5845b0e4 100644 --- a/docs/source/model-export/export-ncnn-zipformer.rst +++ b/docs/source/model-export/export-ncnn-zipformer.rst @@ -72,12 +72,11 @@ Next, we use the following code to export our model: dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --tokens $dir/data/lang_bpe_500/tokens.txt \ --exp-dir $dir/exp \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ - \ --decode-chunk-len 32 \ --num-left-chunks 4 \ --num-encoder-layers "2,4,3,2,4" \ diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index fb952abb7..d95f2acfe 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -71,7 +71,7 @@ Export the model to ONNX .. 
code-block:: bash ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst index efd7dc2e1..31c8f0bf5 100644 --- a/docs/source/model-export/export-with-torch-jit-script.rst +++ b/docs/source/model-export/export-with-torch-jit-script.rst @@ -32,7 +32,7 @@ as an example in the following. ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch $epoch \ --avg $avg \ --jit 1 diff --git a/docs/source/model-export/export-with-torch-jit-trace.rst b/docs/source/model-export/export-with-torch-jit-trace.rst index 506459909..be7876ab5 100644 --- a/docs/source/model-export/export-with-torch-jit-trace.rst +++ b/docs/source/model-export/export-with-torch-jit-trace.rst @@ -33,7 +33,7 @@ as an example in the following. ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --iter $iter \ --avg $avg \ --jit-trace 1 From 9a47c08d085f00b63ce2d7c6d0fee16812691ed7 Mon Sep 17 00:00:00 2001 From: Erwan Zerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:10:50 +0200 Subject: [PATCH 022/113] Update padding modified beam search (#1217) --- .../beam_search.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index fd59d4b7f..97e259b40 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1008,7 +1008,7 @@ def modified_beam_search( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), context_state=None if context_graph is None else context_graph.root, timestamp=[], @@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search( B = HypothesisList() B.add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1753,7 +1753,11 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) max_sym_per_utt = 20000 @@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, 
dtype=torch.float32, device=device), state_cost=NgramLmStateCost(ngram_lm), ) @@ -2446,7 +2450,7 @@ def modified_beam_search_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), @@ -2709,7 +2713,7 @@ def modified_beam_search_lm_shallow_fusion( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, lm_score=init_score.reshape(-1), From fc2df07841b3edbd7bffddfcc2e016515aa75247 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 16 Aug 2023 22:32:41 +0800 Subject: [PATCH 023/113] Add icefall tutorials for dummies. (#1220) --- docs/source/conf.py | 3 + docs/source/for-dummies/data-preparation.rst | 180 ++++++++++ docs/source/for-dummies/decoding.rst | 39 +++ docs/source/for-dummies/environment-setup.rst | 121 +++++++ docs/source/for-dummies/index.rst | 34 ++ docs/source/for-dummies/model-export.rst | 310 ++++++++++++++++++ docs/source/for-dummies/training.rst | 39 +++ docs/source/index.rst | 1 + egs/yesno/ASR/tdnn/onnx_pretrained.py | 1 + 9 files changed, 728 insertions(+) create mode 100644 docs/source/for-dummies/data-preparation.rst create mode 100644 docs/source/for-dummies/decoding.rst create mode 100644 docs/source/for-dummies/environment-setup.rst create mode 100644 docs/source/for-dummies/index.rst create mode 100644 docs/source/for-dummies/model-export.rst create mode 100644 docs/source/for-dummies/training.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index bf231e3c1..5a534e126 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -95,4 +95,7 @@ rst_epilog = """ .. _k2: https://github.com/k2-fsa/k2 .. _lhotse: https://github.com/lhotse-speech/lhotse .. _yesno: https://www.openslr.org/1/ +.. _Next-gen Kaldi: https://github.com/k2-fsa +.. _Kaldi: https://github.com/kaldi-asr/kaldi +.. _lilcom: https://github.com/danpovey/lilcom """ diff --git a/docs/source/for-dummies/data-preparation.rst b/docs/source/for-dummies/data-preparation.rst new file mode 100644 index 000000000..f03d44e79 --- /dev/null +++ b/docs/source/for-dummies/data-preparation.rst @@ -0,0 +1,180 @@ +.. _dummies_tutorial_data_preparation: + +Data Preparation +================ + +After :ref:`dummies_tutorial_environment_setup`, we can start preparing the +data for training and decoding. + +The first step is to prepare the data for training. We have already provided +`prepare.sh `_ +that would prepare everything required for training. + +.. code-block:: + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + ./prepare.sh + +Note that in each recipe from `icefall`_, there exists a file ``prepare.sh``, +which you should run before you run anything else. + +That is all you need for data preparation. + +For the more curious +-------------------- + +If you are wondering how to prepare your own dataset, please refer to the following +URLs for more details: + + - ``_ + + It contains recipes for a variety of dataset. If you want to add your own + dataset, please read recipes in this folder first. + + - ``_ + + The `yesno`_ recipe in `lhotse`_. + +If you already have a `Kaldi`_ dataset directory, which contains files like +``wav.scp``, ``feats.scp``, then you can refer to ``_. 
+ +A quick look to the generated files +----------------------------------- + +``./prepare.sh`` puts generated files into two directories: + + - ``download`` + - ``data`` + +download +^^^^^^^^ + +The ``download`` directory contains downloaded dataset files: + +.. code-block:: bas + + tree -L 1 ./download/ + + ./download/ + |-- waves_yesno + `-- waves_yesno.tar.gz + +.. hint:: + + Please refer to ``_ + for how the data is downloaded and extracted. + +data +^^^^ + +.. code-block:: bash + + tree ./data/ + + ./data/ + |-- fbank + | |-- yesno_cuts_test.jsonl.gz + | |-- yesno_cuts_train.jsonl.gz + | |-- yesno_feats_test.lca + | `-- yesno_feats_train.lca + |-- lang_phone + | |-- HLG.pt + | |-- L.pt + | |-- L_disambig.pt + | |-- Linv.pt + | |-- lexicon.txt + | |-- lexicon_disambig.txt + | |-- tokens.txt + | `-- words.txt + |-- lm + | |-- G.arpa + | `-- G.fst.txt + `-- manifests + |-- yesno_recordings_test.jsonl.gz + |-- yesno_recordings_train.jsonl.gz + |-- yesno_supervisions_test.jsonl.gz + `-- yesno_supervisions_train.jsonl.gz + + 4 directories, 18 files + +**data/manifests**: + + This directory contains manifests. They are used to generate files in + ``data/fbank``. + + To give you an idea of what it contains, we examine the first few lines of + the manifests related to the ``train`` dataset. + + .. code-block:: bash + + cd data/manifests + gunzip -c yesno_recordings_train.jsonl.gz | head -n 3 + + The output is given below: + + .. code-block:: bash + + {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]} + {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]} + {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]} + + Please refer to ``_ + for the meaning of each field per line. + + .. code-block:: bash + + gunzip -c yesno_supervisions_train.jsonl.gz | head -n 3 + + The output is given below: + + .. code-block:: bash + + {"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"} + {"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"} + {"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"} + + Please refer to ``_ + for the meaning of each field per line. + +**data/fbank**: + + This directory contains everything from ``data/manifests``. Furthermore, it also contains features + for training. + + ``data/fbank/yesno_feats_train.lca`` contains the features for the train dataset. + Features are compressed using `lilcom`_. + + ``data/fbank/yesno_cuts_train.jsonl.gz`` stores the `CutSet `_, + which stores `RecordingSet `_, + `SupervisionSet `_, + and `FeatureSet `_. + + To give you an idea about what it looks like, we can run the following command: + + .. 
code-block:: bash + + cd data/fbank + + gunzip -c yesno_cuts_train.jsonl.gz | head -n 3 + + The output is given below: + + .. code-block:: bash + + {"id": "0_0_0_0_1_1_1_1-0", "start": 0, "duration": 6.35, "channel": 0, "supervisions": [{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 635, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.35, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "0,13000,3570", "channels": 0}, "recording": {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}, "type": "MonoCut"} + {"id": "0_0_0_1_0_1_1_0-1", "start": 0, "duration": 6.11, "channel": 0, "supervisions": [{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 611, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.11, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "16570,12964,2929", "channels": 0}, "recording": {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}, "type": "MonoCut"} + {"id": "0_0_1_0_0_1_1_0-2", "start": 0, "duration": 6.02, "channel": 0, "supervisions": [{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 602, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.02, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "32463,12936,2696", "channels": 0}, "recording": {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}, "type": "MonoCut"} + + Note that ``yesno_cuts_train.jsonl.gz`` only stores the information about how to read the features. + The actual features are stored separately in ``data/fbank/yesno_feats_train.lca``. + +**data/lang**: + + This directory contains the lexicon. + +**data/lm**: + + This directory contains language models. diff --git a/docs/source/for-dummies/decoding.rst b/docs/source/for-dummies/decoding.rst new file mode 100644 index 000000000..3e48e8bfd --- /dev/null +++ b/docs/source/for-dummies/decoding.rst @@ -0,0 +1,39 @@ +.. _dummies_tutorial_decoding: + +Decoding +======== + +After :ref:`dummies_tutorial_training`, we can start decoding. + +The command to start the decoding is quite simple: + +.. 
code-block:: bash
+
+   cd /tmp/icefall
+   export PYTHONPATH=/tmp/icefall:$PYTHONPATH
+   cd egs/yesno/ASR
+
+   # We use CPU for decoding by setting the following environment variable
+   export CUDA_VISIBLE_DEVICES=""
+
+   ./tdnn/decode.py
+
+The output logs are given below:
+
+.. literalinclude:: ./code/decoding-yesno.txt
+
+For the more curious
+--------------------
+
+.. code-block:: bash
+
+   ./tdnn/decode.py --help
+
+will print the usage information about ``./tdnn/decode.py``. For instance, you
+can specify:
+
+  - ``--epoch`` to use which checkpoint for decoding
+  - ``--avg`` to select how many checkpoints to use for model averaging
+
+You usually try different combinations of ``--epoch`` and ``--avg`` and select
+one that leads to the lowest WER (`Word Error Rate `_).
diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst
new file mode 100644
index 000000000..0cb8ecc1d
--- /dev/null
+++ b/docs/source/for-dummies/environment-setup.rst
@@ -0,0 +1,121 @@
+.. _dummies_tutorial_environment_setup:
+
+Environment setup
+=================
+
+We will create an environment for `Next-gen Kaldi`_ that runs on ``CPU``
+in this tutorial.
+
+.. note::
+
+   Since the `yesno`_ dataset used in this tutorial is very tiny, training on
+   ``CPU`` works very well for it.
+
+   If your dataset is very large, e.g., hundreds or thousands of hours of
+   training data, please follow :ref:`install icefall` to install `icefall`_
+   that works with ``GPU``.
+
+
+Create a virtual environment
+----------------------------
+
+.. code-block:: bash
+
+   virtualenv -p python3 /tmp/icefall_env
+
+The above command creates a virtual environment in the directory ``/tmp/icefall_env``.
+You can select any directory you want.
+
+The output of the above command is given below:
+
+.. code-block:: bash
+
+   Already using interpreter /usr/bin/python3
+   Using base prefix '/usr'
+   New python executable in /tmp/icefall_env/bin/python3
+   Also creating executable in /tmp/icefall_env/bin/python
+   Installing setuptools, pkg_resources, pip, wheel...done.
+
+Now we can activate the environment using:
+
+.. code-block:: bash
+
+   source /tmp/icefall_env/bin/activate
+
+Install dependencies
+--------------------
+
+.. warning::
+
+   Remember to activate your virtual environment before you continue!
+
+After activating the virtual environment, we can use the following command
+to install dependencies of `icefall`_:
+
+.. hint::
+
+   Remember that we will run this tutorial on ``CPU``, so we install
+   dependencies required only by running on ``CPU``.
+
+.. code-block:: bash
+
+   # Caution: Installation order matters!
+
+   # We use torch 2.0.0 and torchaudio 2.0.0 in this tutorial.
+   # Other versions should also work.
+
+   pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+
+   # If you are using macOS or Windows, please use the following command to install torch and torchaudio
+   # pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+   # Now install k2
+   # Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example
+
+   pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
+
+   # Install the latest version of lhotse
+
+   pip install git+https://github.com/lhotse-speech/lhotse
+
+
+Install icefall
+---------------
+
+We will put the source code of `icefall`_ into the directory ``/tmp``.
+You can select any directory you want.
+
+.. 
code-block:: bash + + cd /tmp + git clone https://github.com/k2-fsa/icefall + cd icefall + pip install -r ./requirements.txt + +.. code-block:: bash + + # Anytime we want to use icefall, we have to set the following + # environment variable + + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + +.. hint:: + + If you get the following error during this tutorial: + + .. code-block:: bash + + ModuleNotFoundError: No module named 'icefall' + + please set the above environment variable to fix it. + + +Congratulations! You have installed `icefall`_ successfully. + +For the more curious +-------------------- + +`icefall`_ contains a collection of Python scripts and you don't need to +use ``python3 setup.py install`` or ``pip install icefall`` to install it. +All you need to do is to download the code and set the environment variable +``PYTHONPATH``. diff --git a/docs/source/for-dummies/index.rst b/docs/source/for-dummies/index.rst new file mode 100644 index 000000000..7c0a3d8ee --- /dev/null +++ b/docs/source/for-dummies/index.rst @@ -0,0 +1,34 @@ +Icefall for dummies tutorial +============================ + +This tutorial walks you step by step about how to create a simple +ASR (`Automatic Speech Recognition `_) +system with `Next-gen Kaldi`_. + +We use the `yesno`_ dataset for demonstration. We select it out of two reasons: + + - It is quite tiny, containing only about 12 minutes of data + - The training can be finished within 20 seconds on ``CPU``. + +That also means you don't need a ``GPU`` to run this tutorial. + +Let's get started! + +Please follow items below **sequentially**. + +.. note:: + + The :ref:`dummies_tutorial_data_preparation` runs only on Linux and on macOS. + All other parts run on Linux, macOS, and Windows. + + Help from the community is appreciated to port the :ref:`dummies_tutorial_data_preparation` + to Windows. + +.. toctree:: + :maxdepth: 2 + + ./environment-setup.rst + ./data-preparation.rst + ./training.rst + ./decoding.rst + ./model-export.rst diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst new file mode 100644 index 000000000..079ebc712 --- /dev/null +++ b/docs/source/for-dummies/model-export.rst @@ -0,0 +1,310 @@ +Model Export +============ + +There are three ways to export a pre-trained model. + + - Export the model parameters via `model.state_dict() `_ + - Export via `torchscript `_: either `torch.jit.script() `_ or `torch.jit.trace() `_ + - Export to `ONNX`_ via `torch.onnx.export() `_ + +Each method is explained below in detail. + +Export the model parameters via model.state_dict() +--------------------------------------------------- + +The command for this kind of export is + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # assume that "--epoch 14 --avg 2" produces the lowest WER. + + ./tdnn/export.py --epoch 14 --avg 2 + +The output logs are given below: + +.. 
code-block:: bash + + 2023-08-16 20:42:03,912 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': False} + 2023-08-16 20:42:03,913 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:42:03,950 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-08-16 20:42:03,971 INFO [export.py:106] Not using torch.jit.script + 2023-08-16 20:42:03,974 INFO [export.py:111] Saved to tdnn/exp/pretrained.pt + +We can see from the logs that the exported model is saved to the file ``tdnn/exp/pretrained.pt``. + +To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the following command: + +.. code-block:: python3 + + >>> import torch + >>> m = torch.load("tdnn/exp/pretrained.pt") + >>> list(m.keys()) + ['model'] + >>> list(m["model"].keys()) + ['tdnn.0.weight', 'tdnn.0.bias', 'tdnn.2.running_mean', 'tdnn.2.running_var', 'tdnn.2.num_batches_tracked', 'tdnn.3.weight', 'tdnn.3.bias', 'tdnn.5.running_mean', 'tdnn.5.running_var', 'tdnn.5.num_batches_tracked', 'tdnn.6.weight', 'tdnn.6.bias', 'tdnn.8.running_mean', 'tdnn.8.running_var', 'tdnn.8.num_batches_tracked', 'output_linear.weight', 'output_linear.bias'] + +We can use ``tdnn/exp/pretrained.pt`` in the following way with ``./tdnn/decode.py``: + +.. code-block:: bash + + cd tdnn/exp + ln -s pretrained.pt epoch-99.pt + cd ../.. + + ./tdnn/decode.py --epoch 99 --avg 1 + +The output logs of the above command are given below: + +.. 
code-block:: bash + + 2023-08-16 20:45:48,089 INFO [decode.py:262] Decoding started + 2023-08-16 20:45:48,090 INFO [decode.py:263] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 99, 'avg': 1, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': False, 'k2-git-sha1': 'ad79f1c699c684de9785ed6ca5edb805a41f78c3', 'k2-git-date': 'Wed Jul 26 09:30:42 2023', 'lhotse-version': '1.16.0.dev+git.aa073f6.clean', 'torch-version': '2.0.0', 'torch-cuda-available': False, 'torch-cuda-version': None, 'python-version': '3.1', 'icefall-git-branch': 'master', 'icefall-git-sha1': '9a47c08-clean', 'icefall-git-date': 'Mon Aug 14 22:10:50 2023', 'icefall-path': '/private/tmp/icefall', 'k2-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/k2/__init__.py', 'lhotse-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/lhotse/__init__.py', 'hostname': 'fangjuns-MacBook-Pro.local', 'IP address': '127.0.0.1'}} + 2023-08-16 20:45:48,092 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:45:48,103 INFO [decode.py:272] device: cpu + 2023-08-16 20:45:48,109 INFO [checkpoint.py:112] Loading checkpoint from tdnn/exp/epoch-99.pt + 2023-08-16 20:45:48,115 INFO [asr_datamodule.py:218] About to get test cuts + 2023-08-16 20:45:48,115 INFO [asr_datamodule.py:253] About to get test cuts + 2023-08-16 20:45:50,386 INFO [decode.py:203] batch 0/?, cuts processed until now is 4 + 2023-08-16 20:45:50,556 INFO [decode.py:240] The transcripts are stored in tdnn/exp/recogs-test_set.txt + 2023-08-16 20:45:50,557 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] + 2023-08-16 20:45:50,558 INFO [decode.py:248] Wrote detailed error stats to tdnn/exp/errs-test_set.txt + 2023-08-16 20:45:50,559 INFO [decode.py:315] Done! + +We can see that it produces an identical WER as before. + +We can also use it to decode files with the following command: + +.. code-block:: bash + + # ./tdnn/pretrained.py requires kaldifeat + # + # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html + # for how to install kaldifeat + + pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + + ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +The output is given below: + +.. 
code-block:: bash + + 2023-08-16 20:53:19,208 INFO [pretrained.py:136] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tdnn/exp/pretrained.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']} + 2023-08-16 20:53:19,208 INFO [pretrained.py:142] device: cpu + 2023-08-16 20:53:19,208 INFO [pretrained.py:144] Creating model + 2023-08-16 20:53:19,212 INFO [pretrained.py:156] Loading HLG from ./data/lang_phone/HLG.pt + 2023-08-16 20:53:19,213 INFO [pretrained.py:160] Constructing Fbank computer + 2023-08-16 20:53:19,213 INFO [pretrained.py:170] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav'] + 2023-08-16 20:53:19,224 INFO [pretrained.py:176] Decoding started + 2023-08-16 20:53:19,304 INFO [pretrained.py:212] + download/waves_yesno/0_0_0_1_0_0_0_1.wav: + NO NO NO YES NO NO NO YES + + download/waves_yesno/0_0_1_0_0_0_1_0.wav: + NO NO YES NO NO NO YES NO + + + 2023-08-16 20:53:19,304 INFO [pretrained.py:214] Decoding Done + + +Export via torch.jit.script() +----------------------------- + +The command for this kind of export is + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # assume that "--epoch 14 --avg 2" produces the lowest WER. + + ./tdnn/export.py --epoch 14 --avg 2 --jit true + +The output logs are given below: + +.. code-block:: bash + + 2023-08-16 20:47:44,666 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': True} + 2023-08-16 20:47:44,667 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:47:44,670 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-08-16 20:47:44,677 INFO [export.py:100] Using torch.jit.script + 2023-08-16 20:47:44,843 INFO [export.py:104] Saved to tdnn/exp/cpu_jit.pt + +From the output logs we can see that the generated file is saved to ``tdnn/exp/cpu_jit.pt``. + +Don't be confused by the name ``cpu_jit.pt``. The ``cpu`` part means the model is moved to +CPU before exporting. That means, when you load it with: + +.. code-block:: bash + + torch.jit.load() + +you don't need to specify the argument `map_location `_ +and it resides on CPU by default. + +To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use: + +.. code-block:: bash + + # ./tdnn/jit_pretrained.py requires kaldifeat + # + # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html + # for how to install kaldifeat + + pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + + + ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +The output is given below: + +.. 
code-block:: bash + + 2023-08-16 20:56:00,603 INFO [jit_pretrained.py:121] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/cpu_jit.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']} + 2023-08-16 20:56:00,603 INFO [jit_pretrained.py:127] device: cpu + 2023-08-16 20:56:00,603 INFO [jit_pretrained.py:129] Loading torchscript model + 2023-08-16 20:56:00,640 INFO [jit_pretrained.py:134] Loading HLG from ./data/lang_phone/HLG.pt + 2023-08-16 20:56:00,641 INFO [jit_pretrained.py:138] Constructing Fbank computer + 2023-08-16 20:56:00,641 INFO [jit_pretrained.py:148] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav'] + 2023-08-16 20:56:00,642 INFO [jit_pretrained.py:154] Decoding started + 2023-08-16 20:56:00,727 INFO [jit_pretrained.py:190] + download/waves_yesno/0_0_0_1_0_0_0_1.wav: + NO NO NO YES NO NO NO YES + + download/waves_yesno/0_0_1_0_0_0_1_0.wav: + NO NO YES NO NO NO YES NO + + + 2023-08-16 20:56:00,727 INFO [jit_pretrained.py:192] Decoding Done + +.. hint:: + + We provide only code for ``torch.jit.script()``. You can try ``torch.jit.trace()`` + if you want. + +Export via torch.onnx.export() +------------------------------ + +The command for this kind of export is + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # tdnn/export_onnx.py requires onnx and onnxruntime + pip install onnx onnxruntime + + # assume that "--epoch 14 --avg 2" produces the lowest WER. + + ./tdnn/export_onnx.py \ + --epoch 14 \ + --avg 2 + +The output logs are given below: + +.. 
code-block:: bash + + 2023-08-16 20:59:20,888 INFO [export_onnx.py:83] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2} + 2023-08-16 20:59:20,888 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:59:20,892 INFO [export_onnx.py:100] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + ================ Diagnostic Run torch.onnx.export version 2.0.0 ================ + verbose: False, log level: Level.ERROR + ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ======================== + + 2023-08-16 20:59:21,047 INFO [export_onnx.py:127] Saved to tdnn/exp/model-epoch-14-avg-2.onnx + 2023-08-16 20:59:21,047 INFO [export_onnx.py:136] meta_data: {'model_type': 'tdnn', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'non-streaming tdnn for the yesno recipe', 'vocab_size': 4} + 2023-08-16 20:59:21,049 INFO [export_onnx.py:140] Generate int8 quantization models + 2023-08-16 20:59:21,075 INFO [onnx_quantizer.py:538] Quantization parameters for tensor:"/Transpose_1_output_0" not specified + 2023-08-16 20:59:21,081 INFO [export_onnx.py:151] Saved to tdnn/exp/model-epoch-14-avg-2.int8.onnx + +We can see from the logs that it generates two files: + + - ``tdnn/exp/model-epoch-14-avg-2.onnx`` (ONNX model with ``float32`` weights) + - ``tdnn/exp/model-epoch-14-avg-2.int8.onnx`` (ONNX model with ``int8`` weights) + +To use the generated ONNX model files for decoding with `onnxruntime`_, we can use + +.. code-block:: bash + + # ./tdnn/onnx_pretrained.py requires kaldifeat + # + # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html + # for how to install kaldifeat + + pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +The output is given below: + +.. 
code-block:: bash
+
+  2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:166] {'feature_dim': 23, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/model-epoch-14-avg-2.onnx', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
+  2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:171] device: cpu
+  2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:173] Loading onnx model ./tdnn/exp/model-epoch-14-avg-2.onnx
+  2023-08-16 21:03:24,267 INFO [onnx_pretrained.py:176] Loading HLG from ./data/lang_phone/HLG.pt
+  2023-08-16 21:03:24,270 INFO [onnx_pretrained.py:180] Constructing Fbank computer
+  2023-08-16 21:03:24,273 INFO [onnx_pretrained.py:190] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
+  2023-08-16 21:03:24,279 INFO [onnx_pretrained.py:196] Decoding started
+  2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:232]
+  download/waves_yesno/0_0_0_1_0_0_0_1.wav:
+  NO NO NO YES NO NO NO YES
+
+  download/waves_yesno/0_0_1_0_0_0_1_0.wav:
+  NO NO YES NO NO NO YES NO
+
+
+  2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:234] Decoding Done
+
+.. note::
+
+   To use the ``int8`` ONNX model for decoding, please use:
+
+   .. code-block:: bash
+
+      ./tdnn/onnx_pretrained.py \
+        --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
+        --HLG ./data/lang_phone/HLG.pt \
+        --words-file ./data/lang_phone/words.txt \
+        download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+        download/waves_yesno/0_0_1_0_0_0_1_0.wav
+
+For the more curious
+--------------------
+
+If you are wondering how to deploy the model without ``torch``, please
+continue reading. We will show how to use `sherpa-onnx`_ to run the
+exported ONNX models, which depends only on `onnxruntime`_ and does not
+depend on ``torch``.
+
+In this tutorial, we will only demonstrate the usage of `sherpa-onnx`_ with the
+pre-trained model of the `yesno`_ recipe. There are also two other frameworks
+available:
+
+  - `sherpa`_. It works with torchscript models.
+  - `sherpa-ncnn`_. It works with models exported using :ref:`icefall_export_to_ncnn` with `ncnn`_.
+
+Please see ``_ for further details.
diff --git a/docs/source/for-dummies/training.rst b/docs/source/for-dummies/training.rst
new file mode 100644
index 000000000..816ef2d3b
--- /dev/null
+++ b/docs/source/for-dummies/training.rst
@@ -0,0 +1,39 @@
+.. _dummies_tutorial_training:
+
+Training
+========
+
+After :ref:`dummies_tutorial_data_preparation`, we can start training.
+
+The command to start the training is quite simple:
+
+.. code-block:: bash
+
+  cd /tmp/icefall
+  export PYTHONPATH=/tmp/icefall:$PYTHONPATH
+  cd egs/yesno/ASR
+
+  # We use CPU for training by setting the following environment variable
+  export CUDA_VISIBLE_DEVICES=""
+
+  ./tdnn/train.py
+
+That's it!
+
+You can find the training logs below:
+
+.. literalinclude:: ./code/train-yesno.txt
+
+For the more curious
+--------------------
+
+.. code-block:: bash
+
+  ./tdnn/train.py --help
+
+will print the usage information about ``./tdnn/train.py``. For instance, you
+can specify the number of epochs to train and the location to save the training
+results.
+
+The training text logs are saved in ``tdnn/exp/log`` while the tensorboard
+logs are in ``tdnn/exp/tensorboard``.
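+
+For example, you can point ``tensorboard`` at that directory, and you can
+override the options mentioned above on the command line. Below is a sketch;
+the training flags shown are hypothetical, so please confirm the real names
+with ``./tdnn/train.py --help`` first:
+
+.. code-block:: bash
+
+  # monitor training progress in a browser
+  tensorboard --logdir tdnn/exp/tensorboard
+
+  # hypothetical flag names; check ./tdnn/train.py --help for the real ones
+  ./tdnn/train.py --num-epochs 15 --exp-dir ./tdnn/exp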
diff --git a/docs/source/index.rst b/docs/source/index.rst index 0fa8fdd1c..fb539d3f2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,6 +20,7 @@ speech recognition recipes using `k2 `_. :maxdepth: 2 :caption: Contents: + for-dummies/index.rst installation/index docker/index faqs diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index 626473b6e..b23a2a381 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -6,6 +6,7 @@ This file shows how to use an ONNX model for decoding with onnxruntime. Usage: (1) Use a not quantized ONNX model, i.e., a float32 model + ./tdnn/onnx_pretrained.py \ --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ --HLG ./data/lang_phone/HLG.pt \ From 4d7f73ce65e2ce89c6be432ae2f973cb5597474f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 28 Aug 2023 19:37:32 +0800 Subject: [PATCH 024/113] Add context biasing for zipformer recipe (#1204) * Add context biasing for zipformer recipe * support context biasing in modified_beam_search_LODR * fix context graph * Minor fixes --- .../beam_search.py | 33 +++++++ egs/librispeech/ASR/zipformer/decode.py | 88 +++++++++++++++---- icefall/context_graph.py | 43 ++++----- 3 files changed, 122 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 97e259b40..3298568a3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2389,6 +2389,7 @@ def modified_beam_search_LODR( LODR_lm_scale: float, LM: LmScorer, beam: int = 4, + context_graph: Optional[ContextGraph] = None, ) -> List[List[int]]: """This function implements LODR (https://arxiv.org/abs/2203.16776) with `modified_beam_search`. 
It uses a bi-gram language model as the estimate @@ -2457,6 +2458,7 @@ def modified_beam_search_LODR( state_cost=NgramLmStateCost( LODR_lm ), # state of the source domain ngram + context_state=None if context_graph is None else context_graph.root, ) ) @@ -2602,8 +2604,17 @@ def modified_beam_search_LODR( hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] + + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + ys.append(new_token) state_cost = hyp.state_cost.forward_one_step(new_token) @@ -2619,6 +2630,7 @@ def modified_beam_search_LODR( hyp_log_prob += ( lm_score[new_token] * lm_scale + LODR_lm_scale * current_ngram_score + + context_score ) # add the lm score lm_score = scores[count] @@ -2637,10 +2649,31 @@ def modified_beam_search_LODR( state=state, lm_score=lm_score, state_cost=state_cost, + context_state=new_context_state, ) B[i].add(new_hyp) B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 2cc157e7a..3531d657f 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -97,6 +97,7 @@ Usage: import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -122,7 +123,7 @@ from beam_search import ( ) from train import add_model_arguments, get_model, get_params -from icefall import LmScorer, NgramLm +from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -215,6 +216,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - modified_beam_search_LODR - fast_beam_search - fast_beam_search_nbest - fast_beam_search_nbest_oracle @@ -251,7 +253,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding-method is fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -285,7 +287,7 @@ def get_parser(): type=int, default=1, help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", + Used only when --decoding-method is greedy_search""", ) parser.add_argument( @@ -347,6 +349,27 @@ def get_parser(): help="ID of the backoff symbol in the ngram LM", ) + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. 
+ """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) add_model_arguments(parser) return parser @@ -359,6 +382,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, LM: Optional[LmScorer] = None, ngram_lm=None, ngram_lm_scale: float = 0.0, @@ -388,7 +412,7 @@ def decode_one_batch( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. LM: A neural network language model. @@ -493,6 +517,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -515,6 +540,7 @@ def decode_one_batch( LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, + context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -578,16 +604,22 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} - elif params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"beam_size_{params.beam_size}_{key}"] = hyps - return ans + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -599,6 +631,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, LM: Optional[LmScorer] = None, ngram_lm=None, ngram_lm_scale: float = 0.0, @@ -618,7 +651,7 @@ def decode_dataset( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. 
Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -649,6 +682,7 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + context_graph=context_graph, word_table=word_table, batch=batch, LM=LM, @@ -744,6 +778,11 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: @@ -770,6 +809,12 @@ def main(): params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -952,6 +997,18 @@ def main(): decoding_graph = None word_table = None + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -976,6 +1033,7 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + context_graph=context_graph, LM=LM, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, diff --git a/icefall/context_graph.py b/icefall/context_graph.py index c78de30f6..01836df04 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -29,7 +29,7 @@ class ContextState: token: int, token_score: float, node_score: float, - local_node_score: float, + output_score: float, is_end: bool, ): """Create a ContextState. @@ -40,16 +40,15 @@ class ContextState: The id of the root node is always 0. token: The token id. - score: + token_score: The bonus for each token during decoding, which will hopefully boost the token up to survive beam search. node_score: The accumulated bonus from root of graph to current node, it will be used to calculate the score for fail arc. - local_node_score: - The accumulated bonus from last ``end_node``(node with is_end true) - to current_node, it will be used to calculate the score for fail arc. - Node: The local_node_score of a ``end_node`` is 0. + output_score: + The total scores of matched phrases, sum of the node_score of all + the output node for current node. is_end: True if current token is the end of a context. 
""" @@ -57,7 +56,7 @@ class ContextState: self.token = token self.token_score = token_score self.node_score = node_score - self.local_node_score = local_node_score + self.output_score = output_score self.is_end = is_end self.next = {} self.fail = None @@ -93,7 +92,7 @@ class ContextGraph: token=-1, token_score=0, node_score=0, - local_node_score=0, + output_score=0, is_end=False, ) self.root.fail = self.root @@ -131,6 +130,7 @@ class ContextGraph: output = None break node.output = output + node.output_score += 0 if output is None else output.output_score queue.append(node) def build(self, token_ids: List[List[int]]): @@ -153,14 +153,13 @@ class ContextGraph: if token not in node.next: self.num_nodes += 1 is_end = i == len(tokens) - 1 + node_score = node.node_score + self.context_score node.next[token] = ContextState( id=self.num_nodes, token=token, token_score=self.context_score, - node_score=node.node_score + self.context_score, - local_node_score=0 - if is_end - else (node.local_node_score + self.context_score), + node_score=node_score, + output_score=node_score if is_end else 0, is_end=is_end, ) node = node.next[token] @@ -186,8 +185,6 @@ class ContextGraph: if token in state.next: node = state.next[token] score = node.token_score - if state.is_end: - score += state.node_score else: # token not matched # We will trace along the fail arc until it matches the token or reaching @@ -202,14 +199,9 @@ class ContextGraph: node = node.next[token] # The score of the fail path - score = node.node_score - state.local_node_score + score = node.node_score - state.node_score assert node is not None - matched_score = 0 - output = node.output - while output is not None: - matched_score += output.node_score - output = output.output - return (score + matched_score, node) + return (score + node.output_score, node) def finalize(self, state: ContextState) -> Tuple[float, ContextState]: """When reaching the end of the decoded sequence, we need to finalize @@ -227,8 +219,6 @@ class ContextGraph: """ # The score of the fail arc score = -state.node_score - if state.is_end: - score = 0 return (score, self.root) def draw( @@ -307,10 +297,8 @@ class ContextGraph: for token, node in current_node.next.items(): if node.id not in seen: node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".") - local_node_score = f"{node.local_node_score:.2f}".rstrip( - "0" - ).rstrip(".") - label = f"{node.id}/({node_score},{local_node_score})" + output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".") + label = f"{node.id}/({node_score}, {output_score})" if node.is_end: dot.node(str(node.id), label=label, **final_state_attr) else: @@ -391,6 +379,7 @@ if __name__ == "__main__": "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" "HISHE": 9, # "HIS", "S", "SHE", "HE" "SHED": 6, # "S", "SHE", "HE" + "SHELF": 6, # "S", "SHE", "HE" "HELL": 2, # "HE" "HELLO": 7, # "HE", "HELLO" "DHRHISQ": 4, # "HIS", "S" From 3a1ce5963b67413b5d274895a1156e20dc30c3be Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:39:48 +0800 Subject: [PATCH 025/113] Minor fix for documentation (#1229) --- docs/source/decoding-with-langugage-models/LODR.rst | 5 ++++- docs/source/decoding-with-langugage-models/rescoring.rst | 5 ++++- .../source/decoding-with-langugage-models/shallow-fusion.rst | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst index 
b6625ee1d..8cc1a624c 100644 --- a/docs/source/decoding-with-langugage-models/LODR.rst +++ b/docs/source/decoding-with-langugage-models/LODR.rst @@ -71,9 +71,12 @@ As the initial step, let's download the pre-trained model. .. code-block:: bash $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ git lfs pull --include "pretrained.pt" $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded + $ cd ../data/lang_bpe_500 + $ git lfs pull --include bpe.model + $ cd ../../.. To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command: diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst index 02eba9129..4cabaa432 100644 --- a/docs/source/decoding-with-langugage-models/rescoring.rst +++ b/docs/source/decoding-with-langugage-models/rescoring.rst @@ -34,9 +34,12 @@ As the initial step, let's download the pre-trained model. .. code-block:: bash $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ git lfs pull --include "pretrained.pt" $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded + $ cd ../data/lang_bpe_500 + $ git lfs pull --include bpe.model + $ cd ../../.. As usual, we first test the model's performance without external LM. This can be done via the following command: diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst index f15e3f1d9..684fefeb4 100644 --- a/docs/source/decoding-with-langugage-models/shallow-fusion.rst +++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst @@ -32,9 +32,12 @@ As the initial step, let's download the pre-trained model. .. code-block:: bash $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + $ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ git lfs pull --include "pretrained.pt" $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded + $ cd ../data/lang_bpe_500 + $ git lfs pull --include bpe.model + $ cd ../../.. To test the model, let's have a look at the decoding results without using LM. 
This can be done via the following command: From 8fcadb68a7cde093069e89830832e1ac728338fe Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Wed, 30 Aug 2023 22:31:05 -0400 Subject: [PATCH 026/113] Missing definitions in scaling.py added (#1232) --- egs/libricss/SURT/dprnn_zipformer/scaling.py | 1577 +++++++++++++++++- 1 file changed, 1576 insertions(+), 1 deletion(-) mode change 120000 => 100644 egs/libricss/SURT/dprnn_zipformer/scaling.py diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py deleted file mode 120000 index 5f9be9fe0..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py new file mode 100644 index 000000000..4040a7b89 --- /dev/null +++ b/egs/libricss/SURT/dprnn_zipformer/scaling.py @@ -0,0 +1,1576 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import random +from typing import Optional, Tuple, Union + +import torch +import torch.backends.cudnn.rnn as rnn +import torch.nn as nn +from torch import _VF, Tensor + +from icefall.utils import is_jit_tracing + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + if sign_factor is None: + ctx.save_for_backward(xgt0, scale_factor) + else: + ctx.save_for_backward(xgt0, scale_factor, sign_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + if len(ctx.saved_tensors) == 3: + xgt0, scale_factor, sign_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + sign_factor = sign_factor.unsqueeze(-1) + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + else: + xgt0, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) + + if min_abs == 0.0: + below_threshold = 0.0 + else: + # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if + # x_abs)_mean , min_abs. 
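+        # In other words, below_threshold equals
+        #   clamp((min_abs - x_abs_mean) * gain_factor / min_abs, 0, max_factor),
+        # i.e. it is zero when x_abs_mean >= min_abs and grows (capped at
+        # max_factor) as x_abs_mean falls below min_abs.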
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) + + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) + + return below_threshold - above_threshold + + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + if min_positive == 0.0: + factor1 = 0.0 + else: + # 0 if proportion_positive >= min_positive, else can be + # as large as max_factor. + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) + + if max_positive == 1.0: + factor2 = 0.0 + else: + # 0 if self.proportion_positive <= max_positive, else can be + # as large as -max_factor. + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) + sign_factor = factor1 - factor2 + # require min_positive != 0 or max_positive != 1: + assert not isinstance(sign_factor, float) + return sign_factor + + +class ActivationScaleBalancerFunction(torch.autograd.Function): + """ + This object is used in class ActivationBalancer when the user specified + min_positive=0, max_positive=1, so there are no constraints on the signs + of the activations and only the absolute value has a constraint. + """ + + @staticmethod + def forward( + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + ctx.save_for_backward(xgt0, sign_factor, scale_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + xgt0, sign_factor, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + sign_factor = sign_factor.unsqueeze(-1) + scale_factor = scale_factor.unsqueeze(-1) + + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +class RandomClampFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: + x_clamped = torch.clamp(x, min=min, max=max) + mask = torch.rand_like(x) < prob + ans = torch.where(mask, x_clamped, x) + if x.requires_grad: + ctx.save_for_backward(ans == x) + ctx.reflect = reflect + if reflect != 0.0: + ans = ans * (1.0 + reflect) - (x * reflect) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + (is_same,) = ctx.saved_tensors + x_grad = ans_grad * is_same.to(ans_grad.dtype) + reflect = ctx.reflect + if reflect != 0.0: + x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) + return x_grad, None, None, None, None + + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): + return RandomClampFunction.apply(x, min, max, prob, reflect) + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. 
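+    Elements whose magnitude is below ``min_abs`` are randomly snapped to 0 or
+    ``±min_abs`` with probabilities chosen so that the expected value is
+    preserved; all other elements are cast to float16 directly.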
+ """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class RandomGradFunction(torch.autograd.Function): + """ + Does nothing in forward pass; in backward pass, gets rid of very small grads using + randomized approach that preserves expectations (intended to reduce roundoff). + """ + + @staticmethod + def forward(ctx, x: Tensor, min_abs: float) -> Tensor: + ctx.min_abs = min_abs + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: + if ans_grad.dtype == torch.float16: + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) + else: + return ans_grad, None + + +class RandomGrad(torch.nn.Module): + """ + Gets rid of very small gradients using an expectation-preserving method, intended to increase + accuracy of training when using amp (automatic mixed precision) + """ + + def __init__(self, min_abs: float = 5.0e-06): + super(RandomGrad, self).__init__() + self.min_abs = min_abs + + def forward(self, x: Tensor): + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): + return x + else: + return RandomGradFunction.apply(x, self.min_abs) + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. 
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class GradientFilterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + batch_dim: int, # e.g., 1 + threshold: float, # e.g., 10.0 + *params: Tensor, # module parameters + ) -> Tuple[Tensor, ...]: + if x.requires_grad: + if batch_dim < 0: + batch_dim += x.ndim + ctx.batch_dim = batch_dim + ctx.threshold = threshold + return (x,) + params + + @staticmethod + def backward( + ctx, + x_grad: Tensor, + *param_grads: Tensor, + ) -> Tuple[Tensor, ...]: + eps = 1.0e-20 + dim = ctx.batch_dim + norm_dims = [d for d in range(x_grad.ndim) if d != dim] + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() + median_norm = norm_of_batch.median() + + cutoff = median_norm * ctx.threshold + inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) + mask = 1.0 / (inv_mask + eps) + x_grad = x_grad * mask + + avg_mask = 1.0 / (inv_mask.mean() + eps) + param_grads = [avg_mask * g for g in param_grads] + + return (x_grad, None, None) + tuple(param_grads) + + +class GradientFilter(torch.nn.Module): + """This is used to filter out elements that have extremely large gradients + in batch and the module parameters with soft masks. + + Args: + batch_dim (int): + The batch dimension. + threshold (float): + For each element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch. + """ + + def __init__(self, batch_dim: int = 1, threshold: float = 10.0): + super(GradientFilter, self).__init__() + self.batch_dim = batch_dim + self.threshold = threshold + + def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: + if torch.jit.is_scripting() or is_jit_tracing(): + return (x,) + params + else: + return GradientFilterFunction.apply( + x, + self.batch_dim, + self.threshold, + *params, + ) + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. 
+ eps_min: float + eps_max: float + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_min: float = -3.0, + eps_max: float = 3.0, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.eps_min = eps_min + self.eps_max = eps_max + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + eps = self.eps + if self.training and random.random() < 0.25: + # with probability 0.25, in training mode, clamp eps between the min + # and max; this will encourage it to learn parameters within the + # allowed range by making parameters that are outside the allowed + # range noisy. + + # gradients to allow the parameter to get back into the allowed + # region if it happens to exit it. + eps = eps.clamp(min=self.eps_min, max=self.eps_max) + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() + ) ** -0.5 + return x * scales + + +class ScaledEmbedding(nn.Module): + r"""This is a modified version of nn.Embedding that introduces a learnable scale + on the parameters. Note: due to how we initialize it, it's best used with + schedulers like Noam that have a warmup period. + + It is a simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + initial_speed (float, optional): This affects how fast the parameter will + learn near the start of training; you can set it to a value less than + one if you suspect that a module is contributing to instability near + the start of training. Note: regardless of the use of this option, + it's best to use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. + + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. 
However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + + """ + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + initial_speed: float = 1.0, + ) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters(initial_speed) + + def reset_parameters(self, initial_speed: float = 1.0) -> None: + std = 0.1 / initial_speed + nn.init.normal_(self.weight, std=std) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) + else: + return F.embedding( + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self) -> str: + # s = "{num_embeddings}, {embedding_dim}, scale={scale}" + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + return s.format(**self.__dict__) + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of 
nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class ScaledLSTM(nn.LSTM): + # See docs for ScaledLinear. + # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` + # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + grad_norm_threshold: float = 10.0, + **kwargs, + ): + super(ScaledLSTM, self).__init__(*args, **kwargs) + initial_scale = torch.tensor(initial_scale).log() + self._scales_names = [] + self._scales = [] + self.batch_dim = 0 if self.batch_first else 1 + self.num_directions = 1 + int(self.bidirectional) + for name in self._flat_weights_names: + scale_name = name + "_scale" + self._scales_names.append(scale_name) + param = nn.Parameter(initial_scale.clone().detach()) + setattr(self, scale_name, param) + self._scales.append(param) + + self.grad_filter = GradientFilter( + batch_dim=self.batch_dim, threshold=grad_norm_threshold + ) + + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class + + def _reset_parameters(self, initial_speed: float): + std = 0.1 / initial_speed + a = (3**0.5) * std + scale = self.hidden_size**-0.5 + v = scale / std + for idx, name in enumerate(self._flat_weights_names): + if "weight" in name: + nn.init.uniform_(self._flat_weights[idx], -a, a) + with torch.no_grad(): + self._scales[idx] += torch.tensor(v).log() + elif "bias" in name: + nn.init.constant_(self._flat_weights[idx], 0.0) + + def _flatten_parameters(self, flat_weights) -> None: + """Resets parameter data pointer so that they can use faster code paths. + + Right now, this works only if the module is on the GPU and cuDNN is enabled. + Otherwise, it's a no-op. 
+ + This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa + """ + # Short-circuits if _flat_weights is only partially instantiated + if len(flat_weights) != len(self._flat_weights_names): + return + + for w in flat_weights: + if not isinstance(w, Tensor): + return + # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN + # or the tensors in flat_weights are of different dtypes + + first_fw = flat_weights[0] + dtype = first_fw.dtype + for fw in flat_weights: + if ( + not isinstance(fw.data, Tensor) + or not (fw.data.dtype == dtype) + or not fw.data.is_cuda + or not torch.backends.cudnn.is_acceptable(fw.data) + ): + return + + # If any parameters alias, we fall back to the slower, copying code path. This is + # a sufficient check, because overlapping parameter buffers that don't completely + # alias would break the assumptions of the uniqueness check in + # Module.named_parameters(). + unique_data_ptrs = set(p.data_ptr() for p in flat_weights) + if len(unique_data_ptrs) != len(flat_weights): + return + + with torch.cuda.device_of(first_fw): + + # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is + # an inplace operation on self._flat_weights + with torch.no_grad(): + if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 + torch._cudnn_rnn_flatten_weight( + flat_weights, + num_weights, + self.input_size, + rnn.get_cudnn_mode(self.mode), + self.hidden_size, + self.proj_size, + self.num_layers, + self.batch_first, + bool(self.bidirectional), + ) + + def _get_flat_weights(self): + """Get scaled weights, and resets their data pointer.""" + flat_weights = [] + for idx in range(len(self._flat_weights_names)): + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) + self._flatten_parameters(flat_weights) + return flat_weights + + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): + # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa + # The change for calling `_VF.lstm()` is: + # self._flat_weights -> self._get_flat_weights() + if hx is None: + h_zeros = torch.zeros( + self.num_layers * self.num_directions, + input.size(self.batch_dim), + self.proj_size if self.proj_size > 0 else self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * self.num_directions, + input.size(self.batch_dim), + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + + self.check_forward_args(input, hx, None) + + flat_weights = self._get_flat_weights() + input, *flat_weights = self.grad_filter(input, *flat_weights) + + result = _VF.lstm( + input, + hx, + flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + + output = result[0] + hidden = result[1:] + return output, hidden + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. 
+ + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + min_prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): + super(ActivationBalancer, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor + + # count measures how many times the forward() function has been called. + # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. + self.cpu_count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): + return _no_op(x) + + count = self.cpu_count + self.cpu_count += 1 + + if random.random() < 0.01: + # Occasionally sync self.cpu_count with self.count. + # count affects the decay of 'prob'. don't do this on every iter, + # because syncing with the GPU is slow. 
+ self.cpu_count = max(self.cpu_count, self.count.item()) + self.count.fill_(self.cpu_count) + + # the prob of doing some work exponentially decreases from 0.5 till it hits + # a floor at min_prob (==0.1, by default) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + + if random.random() < prob: + sign_gain_factor = 0.5 + if self.min_positive != 0.0 or self.max_positive != 1.0: + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) + else: + sign_factor = None + + scale_factor = _compute_scale_factor( + x.detach(), + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) + return ActivationBalancerFunction.apply( + x, + scale_factor, + sign_factor, + self.channel_dim, + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. 
+ """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float + ) -> Tensor: + ctx.save_for_backward(x) + ctx.num_groups = num_groups + ctx.whitening_limit = whitening_limit + ctx.grad_scale = grad_scale + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, ctx.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) + + (metric - ctx.whitening_limit).relu().backward() + penalty_grad = x_detached.grad + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 
0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert whitening_limit >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + if isinstance(prob, float): + assert 0 < prob <= 1 + self.prob = prob + else: + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob < self.max_prob <= 1 + self.prob = self.max_prob + + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: + return _no_op(x) + else: + if hasattr(self, "min_prob") and random.random() < 0.25: + # occasionally switch between min_prob and max_prob, based on whether + # we are above or below the threshold. + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): + # there would be a change to the grad. + self.prob = self.max_prob + else: + self.prob = self.min_prob + + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor): + ctx.y_shape = y.shape + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + ) + + +def with_loss(x, y): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y) + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class MaxEig(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to discourage + that any given direction in activation space accounts for more than + a specified proportion of the covariance (e.g. 0.2). + + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + max_var_per_eig: the maximum proportion of the variance of the + features/channels, after mean subtraction, that can come from + any given eigenvalue. + min_prob: the minimum probability with which we apply this during any invocation + of forward(), assuming last time we applied the constraint it was + not active; supplied for speed. 
+ scale: determines the scale with which we modify the gradients, relative + to the existing / unmodified gradients + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, + ): + super(MaxEig, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.scale = scale + assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels + self.max_var_per_eig = max_var_per_eig + + # we figure out the dominant direction using the power method: starting with + # a random vector, keep multiplying by the covariance and renormalizing. + with torch.no_grad(): + # arbitrary.. would use randn() but want to leave the rest of the model's + # random parameters unchanged for comparison + direction = torch.arange(num_channels).to(torch.float) + direction = direction / direction.norm() + self.register_buffer("max_eig_direction", direction) + + self.min_prob = min_prob + # cur_prob is the current probability we'll use to apply the ActivationBalancer. + # We'll regress this towards prob, each tiem we try to apply it and it is not + # active. + self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + or torch.jit.is_tracing() + ): + return _no_op(x) + + with torch.cuda.amp.autocast(enabled=False): + eps = 1.0e-20 + orig_x = x + x = x.to(torch.float32) + with torch.no_grad(): + x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) + x = x - x.mean(dim=0) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + + # ensure new direction is nonzero even if x == 0, by including `direction`. + self._set_direction(0.1 * self.max_eig_direction + new_direction) + + if random.random() < 0.01 or __name__ == "__main__": + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) + + if variance_proportion >= self.max_var_per_eig: + # The constraint is active. Note, we should quite rarely + # reach here, only near the beginning of training if we are + # starting to diverge, should this constraint be active. + cur_prob = self.cur_prob + self.cur_prob = 1.0 # next time, do the update with probability 1.0. + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) + else: + # let self.cur_prob exponentially approach self.min_prob, as + # long as the constraint is inactive. 
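                # Each time this branch runs and finds the constraint inactive,
                # cur_prob moves 25% of the way towards min_prob: starting from 1.0
                # with the default min_prob=0.01 this gives roughly 0.75, 0.57, 0.43, ...,
                # i.e. about min_prob + (1 - min_prob) * 0.75**n after n such updates.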
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob + return orig_x + + def _set_direction(self, direction: Tensor): + """ + Sets self.max_eig_direction to a normalized version of `direction` + """ + direction = direction.detach() + direction = direction / direction.norm() + direction_sum = direction.sum().item() + if direction_sum - direction_sum == 0: # no inf/nan + self.max_eig_direction[:] = direction + else: + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) + + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ + (num_frames, num_channels) = x.shape + assert num_channels > 1 and num_frames > 1 + assert prev_direction.shape == (num_channels,) + # `coeffs` are the coefficients of `prev_direction` in x. + # actually represent the coeffs up to a constant positive factor. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) + return cur_direction, coeffs + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.043637 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) + if __name__ == "__main__": + # for self-testing only. 
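                # These checks confirm that the quantized derivative fits in a uint8:
                # (deriv - floor) * 255 / (ceil - floor) should lie in [0, 255], and the
                # uniform noise added above is < 1, so d_scaled must stay within [0, 256).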
+ assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +def _test_max_eig(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad, atol=1.0e-02) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + min_prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = 
m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() + _test_whiten() + _test_max_eig() + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() From 9ef8145fa3c6e8f45fa8ad8e8e4d348062b84ee4 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 4 Sep 2023 17:56:05 +0800 Subject: [PATCH 027/113] minor fixes (#1240) --- egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py | 1 + egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py | 1 + 2 files changed, 2 insertions(+) diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index 20d7341db..1af08fee2 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -28,6 +28,7 @@ from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWri # even when we are not invoking the main (e.g. when spawning subprocesses). torch.set_num_threads(1) torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") def compute_fbank_wenetspeech_dev_test(): diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index 1b257fb70..99d39bbdc 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -37,6 +37,7 @@ from lhotse import ( # even when we are not invoking the main (e.g. when spawning subprocesses). 
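# A presumed motivation for the set_sharing_strategy("file_system") line added below
# (the commit message only says "minor fixes"): PyTorch's default "file_descriptor"
# sharing strategy keeps an open descriptor per tensor shared with worker processes,
# which can exhaust the per-process file-descriptor limit when many dataloader
# workers are used; the "file_system" strategy avoids that.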
torch.set_num_threads(1) torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") def get_parser(): From d50a9ea03055232e742a753dd5e5e4cad914caa6 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 7 Sep 2023 16:34:53 +0800 Subject: [PATCH 028/113] doc str fixes (#1241) --- .../ASR/pruned_transducer_stateless7/compute_ali.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py index 8bcb56d62..27ef0a244 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -26,7 +26,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 @@ -52,12 +52,12 @@ import torch import torch.nn as nn from alignment import batch_force_alignment from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp from lhotse import CutSet from lhotse.serialization import SequentialJsonlWriter from lhotse.supervision import AlignmentItem +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp def get_parser(): From c912bd65d0c301233e8d18fb1e1ea0e9c4c245d5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 7 Sep 2023 18:48:27 +0800 Subject: [PATCH 029/113] Update run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh (#1242) --- .../run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh index c8d9c6b77..b61a9d7b6 100755 --- a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh +++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh @@ -29,6 +29,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == ls -lh data/fbank ls -lh pruned_transducer_stateless2/exp + ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz + ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz + log "Decoding dev and test" # use a small value for decoding with CPU From 49a4b672884213809cc04df2caab6c37cee92c22 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 7 Sep 2023 19:48:46 +0800 Subject: [PATCH 030/113] fixed a CI test issue related to python version (#1243) --- .github/workflows/run-aishell-2022-06-20.yml | 2 +- .github/workflows/run-gigaspeech-2022-05-13.yml | 2 +- .github/workflows/run-librispeech-2022-03-12.yml | 2 +- .github/workflows/run-librispeech-2022-04-29.yml | 2 +- .github/workflows/run-librispeech-2022-05-13.yml | 2 +- .../run-librispeech-pruned-transducer-stateless3-2022-05-13.yml | 2 +- ...n-librispeech-streaming-transducer-stateless2-2022-06-26.yml | 2 +- .../run-librispeech-transducer-stateless2-2022-04-19.yml | 2 +- .github/workflows/run-pretrained-conformer-ctc.yml | 2 +- .../run-pretrained-transducer-stateless-librispeech-100h.yml | 2 +- ...etrained-transducer-stateless-librispeech-multi-datasets.yml 
| 2 +- .../run-pretrained-transducer-stateless-modified-2-aishell.yml | 2 +- .../run-pretrained-transducer-stateless-modified-aishell.yml | 2 +- .github/workflows/run-pretrained-transducer-stateless.yml | 2 +- .github/workflows/run-pretrained-transducer.yml | 2 +- 15 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml index d14196f38..53fcb2c03 100644 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -45,7 +45,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index 0e47f7538..3121520c1 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 3edbe43ec..f092e3c80 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index bb44a073b..f8f4d9977 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml index e7b53b21c..dc20185da 100644 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ b/.github/workflows/run-librispeech-2022-05-13.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index bf73d4f18..3fb0920bc 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml index 6ea308468..67a6f6fc4 100644 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index 9fe2f0389..35ca08a31 100644 --- 
a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -44,7 +44,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index bcd326b9d..6151a5a14 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -34,7 +34,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index 1e5b25f5c..f8caee8e5 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -43,7 +43,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index 9063c0ed6..7c3910eb8 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -43,7 +43,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 2d24528d3..ce6d6f92d 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -34,7 +34,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index 761b26131..f0cebd94a 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -34,7 +34,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index e46b9a849..1b69b97bf 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -43,7 +43,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 190e446bc..91d87f1c9 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -34,7 +34,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] fail-fast: false From 
3199058194a48d45aeee740f2aa9bdbef0bec29d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 9 Sep 2023 21:25:26 +0800 Subject: [PATCH 031/113] enable `sclite_mode` for swbd scoring (#1239) --- icefall/utils.py | 3 ++- requirements-ci.txt | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index b01cd2770..947d79438 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -493,6 +493,7 @@ def write_error_stats( test_set_name: str, results: List[Tuple[str, str]], enable_log: bool = True, + sclite_mode: bool = False, ) -> float: """Write statistics based on predicted results and reference transcripts. @@ -538,7 +539,7 @@ def write_error_stats( num_corr = 0 ERR = "*" for cut_id, ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR) + ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) for ref_word, hyp_word in ali: if ref_word == ERR: ins[hyp_word] += 1 diff --git a/requirements-ci.txt b/requirements-ci.txt index 3c2eb5f65..21d33001c 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -15,7 +15,7 @@ graphviz==0.19.1 git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 -kaldialign==0.2 +kaldialign==0.7.1 sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 From 0f1bc6f8af63d585436837b2b14f5075cd680480 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 13 Sep 2023 11:57:05 +0800 Subject: [PATCH 032/113] Multi_zh-Hans Recipe (#1238) * Init commit for recipes trained on multiple zh datasets. * fbank extraction for thchs30 * added support for aishell1 * added support for aishell-2 * fixes * fixes * fixes * added support for stcmds and primewords * fixes * added support for magicdata script for fbank computation not done yet * added script for magicdata fbank computation * file permission fixed * updated for the wenetspeech recipe * updated * Update preprocess_kespeech.py * updated * updated * updated * updated * file permission fixed * updated paths * fixes * added support for kespeech dev/test set fbank computation * fixes for file permission * refined support for KeSpeech * added scripts for BPE model training * updated * init commit for the multi_zh-cn zipformer recipe * disable speed perturbation by default * updated * updated * added necessary files for the zipformer recipe * removed redundant wenetspeech M and S sets * updates for multi dataset decoding * refined * formatting issues fixed * updated * minor fixes * this commit finalize the recipe (hopefully) * fixed formatting issues * minor fixes * updated * using soft links to reduce redundancy * minor updates * using soft links to reduce redundancy * minor updates * minor updates * using soft links to reduce redundancy * minor updates * Update README.md * minor updates * Update egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py Co-authored-by: Fangjun Kuang * Update egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py Co-authored-by: Fangjun Kuang * Update egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py Co-authored-by: Fangjun Kuang * Update egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py Co-authored-by: Fangjun Kuang * Update egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py Co-authored-by: Fangjun Kuang * Update egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py Co-authored-by: Fangjun Kuang * minor updates * minor fixes * fixed a formatting issue * Update preprocess_kespeech.py * Update prepare.sh * Update egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py Co-authored-by: Fangjun Kuang * Update 
egs/multi_zh-hans/ASR/local/preprocess_kespeech.py Co-authored-by: Fangjun Kuang * removed redundant files * symlinks added * minor updates * added CI tests for `multi_zh-hans` * minor fixes * Update run-multi-zh_hans-zipformer.sh * Update run-multi-zh_hans-zipformer.sh * Update run-multi-zh_hans-zipformer.sh * Update run-multi-zh_hans-zipformer.sh * Update run-multi-zh_hans-zipformer.sh * Update run-multi-zh_hans-zipformer.sh * Update run-multi-zh_hans-zipformer.sh --------- Co-authored-by: Fangjun Kuang --- .../scripts/run-multi-zh_hans-zipformer.sh | 51 + .../workflows/run-multi-zh_hans-zipformer.yml | 84 + egs/librispeech/ASR/zipformer/zipformer.py | 916 ++++++----- egs/multi_zh-hans/ASR/README.md | 39 + egs/multi_zh-hans/ASR/RESULTS.md | 38 + .../ASR/local/bpe_model_to_tokens.py | 37 + egs/multi_zh-hans/ASR/local/compile_lg.py | 1 + .../local/compute_fbank_kespeech_dev_test.py | 93 ++ .../local/compute_fbank_kespeech_splits.py | 180 +++ .../ASR/local/compute_fbank_magicdata.py | 122 ++ .../ASR/local/compute_fbank_primewords.py | 122 ++ .../ASR/local/compute_fbank_stcmds.py | 121 ++ .../ASR/local/compute_fbank_thchs30.py | 127 ++ egs/multi_zh-hans/ASR/local/prepare_char.py | 1 + .../ASR/local/prepare_for_bpe_model.py | 65 + egs/multi_zh-hans/ASR/local/prepare_lang.py | 1 + .../ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/preprocess_kespeech.py | 151 ++ egs/multi_zh-hans/ASR/local/text2token.py | 1 + .../ASR/local/train_bpe_model.py | 109 ++ .../ASR/local/validate_bpe_lexicon.py | 1 + egs/multi_zh-hans/ASR/prepare.sh | 373 +++++ egs/multi_zh-hans/ASR/shared | 1 + .../ASR/zipformer/asr_datamodule.py | 388 +++++ .../ASR/zipformer/beam_search.py | 1 + egs/multi_zh-hans/ASR/zipformer/decode.py | 828 ++++++++++ egs/multi_zh-hans/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + .../ASR/zipformer/export-onnx.py | 1 + egs/multi_zh-hans/ASR/zipformer/export.py | 541 +++++++ .../ASR/zipformer/generate_averaged_model.py | 193 +++ .../ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/multi_zh-hans/ASR/zipformer/joiner.py | 1 + egs/multi_zh-hans/ASR/zipformer/model.py | 1 + .../ASR/zipformer/multi_dataset.py | 316 ++++ egs/multi_zh-hans/ASR/zipformer/onnx_check.py | 1 + .../ASR/zipformer/onnx_decode.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/multi_zh-hans/ASR/zipformer/optim.py | 1 + egs/multi_zh-hans/ASR/zipformer/pretrained.py | 381 +++++ egs/multi_zh-hans/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 1 + .../ASR/zipformer/subsampling.py | 1 + egs/multi_zh-hans/ASR/zipformer/train.py | 1385 +++++++++++++++++ egs/multi_zh-hans/ASR/zipformer/zipformer.py | 1 + 51 files changed, 6319 insertions(+), 369 deletions(-) create mode 100755 .github/scripts/run-multi-zh_hans-zipformer.sh create mode 100644 .github/workflows/run-multi-zh_hans-zipformer.yml create mode 100644 egs/multi_zh-hans/ASR/README.md create mode 100644 egs/multi_zh-hans/ASR/RESULTS.md create mode 100755 egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py create mode 120000 egs/multi_zh-hans/ASR/local/compile_lg.py create mode 100755 egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py create mode 100755 egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py create mode 100755 
egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py create mode 100755 egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py create mode 100755 egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py create mode 100755 egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py create mode 120000 egs/multi_zh-hans/ASR/local/prepare_char.py create mode 100755 egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py create mode 120000 egs/multi_zh-hans/ASR/local/prepare_lang.py create mode 120000 egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py create mode 100755 egs/multi_zh-hans/ASR/local/preprocess_kespeech.py create mode 120000 egs/multi_zh-hans/ASR/local/text2token.py create mode 100755 egs/multi_zh-hans/ASR/local/train_bpe_model.py create mode 120000 egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/multi_zh-hans/ASR/prepare.sh create mode 120000 egs/multi_zh-hans/ASR/shared create mode 100644 egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/beam_search.py create mode 100755 egs/multi_zh-hans/ASR/zipformer/decode.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/decoder.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/encoder_interface.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/export-onnx.py create mode 100755 egs/multi_zh-hans/ASR/zipformer/export.py create mode 100755 egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/joiner.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/model.py create mode 100644 egs/multi_zh-hans/ASR/zipformer/multi_dataset.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_check.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_decode.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/optim.py create mode 100755 egs/multi_zh-hans/ASR/zipformer/pretrained.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/scaling.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/scaling_converter.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/streaming_decode.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/subsampling.py create mode 100755 egs/multi_zh-hans/ASR/zipformer/train.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/zipformer.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh new file mode 100755 index 000000000..2bc3137d8 --- /dev/null +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/multi_zh-hans/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav 
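# The checkpoint shipped in the downloaded repo is epoch-20.pt; the symlink below
# exposes it as epoch-99.pt, the fixed name that the decoding commands further down
# pass via --checkpoint.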
+ +pushd $repo/exp +ln -s epoch-20.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --method greedy_search \ +$repo/test_wavs/DEV_T0000000000.wav \ +$repo/test_wavs/DEV_T0000000001.wav \ +$repo/test_wavs/DEV_T0000000002.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav +done diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-zh_hans-zipformer.yml new file mode 100644 index 000000000..4ec81585f --- /dev/null +++ b/.github/workflows/run-multi-zh_hans-zipformer.yml @@ -0,0 +1,84 @@ +# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-multi-zh_hans-zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +concurrency: + group: run_multi-zh_hans_zipformer-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_multi-zh_hans_zipformer: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-multi-zh_hans-zipformer.sh diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index b39af02b8..1a174b315 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ 
b/egs/librispeech/ASR/zipformer/zipformer.py @@ -91,34 +91,34 @@ class Zipformer2(EncoderInterface): chunks. Must not be less than cnn_module_kernel (after factoring in rounding and downsampling); an error will be thrown if this is violated. """ + def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], ) -> None: super(Zipformer2, self).__init__() if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), - (20000.0, 0.1)) + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) def _to_tuple(x): - """ Converts a single int or a 1-tuple of an int to a tuple with the same length + """Converts a single int or a 1-tuple of an int to a tuple with the same length as downsampling_factor""" if isinstance(x, int): x = (x,) @@ -128,10 +128,12 @@ class Zipformer2(EncoderInterface): assert len(x) == len(downsampling_factor) and isinstance(x[0], int) return x - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple num_encoder_layers = _to_tuple(num_encoder_layers) self.num_encoder_layers = num_encoder_layers self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) @@ -145,7 +147,7 @@ class Zipformer2(EncoderInterface): self.chunk_size = chunk_size self.left_context_frames = left_context_frames - for u,d in zip(encoder_unmasked_dim, encoder_dim): + for u, d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder @@ -153,7 +155,6 @@ class Zipformer2(EncoderInterface): num_encoders = len(downsampling_factor) for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( embed_dim=encoder_dim[i], pos_dim=pos_dim, @@ 
-191,13 +192,11 @@ class Zipformer2(EncoderInterface): self.encoders = nn.ModuleList(encoders) - self.downsample_output = SimpleDownsample(max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout) + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) - def get_feature_masks( - self, - x: Tensor) -> Union[List[float], List[Tensor]]: + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. @@ -215,24 +214,30 @@ class Zipformer2(EncoderInterface): """ num_encoders = len(self.encoder_dim) if not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0) + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) feature_mask_dropout_prob = 0.125 # mask1 shape: (1, batch_size, 1) - mask1 = (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and(mask1, - (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) # dim: (1, batch_size, 2) mask = torch.cat((mask1, mask2), dim=-1) @@ -240,8 +245,9 @@ class Zipformer2(EncoderInterface): feature_masks = [] for i in range(num_encoders): channels = self.encoder_dim[i] - feature_mask = torch.ones(1, batch_size, channels, - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) u1 = self.encoder_unmasked_dim[i] u2 = u1 + (channels - u1) // 2 @@ -281,7 +287,8 @@ class Zipformer2(EncoderInterface): return chunk_size, left_context_chunks def forward( - self, x: Tensor, + self, + x: Tensor, x_lens: Tensor, src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: @@ -319,12 +326,17 @@ class Zipformer2(EncoderInterface): ds = self.downsampling_factor[i] x = convert_num_channels(x, self.encoder_dim[i]) - x = module(x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=(None if src_key_padding_mask is None - else src_key_padding_mask[...,::ds]), - attn_mask=attn_mask) + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) outputs.append(x) # if the last output has the largest dimension, x will be unchanged, @@ -345,9 +357,7 @@ class Zipformer2(EncoderInterface): return x, lengths def _get_attn_mask( - self, x: Tensor, - chunk_size: int, - left_context_chunks: int + self, x: Tensor, chunk_size: int, left_context_chunks: int ) -> Optional[Tensor]: """ Return None if chunk_size == -1, else return attention mask of shape @@ -362,9 +372,11 @@ class Zipformer2(EncoderInterface): assert all(chunk_size % d == 0 for d in self.downsampling_factor) if left_context_chunks >= 0: num_encoders = len(self.encoder_dim) - assert all (chunk_size * left_context_chunks >= - 
(self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders)) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) else: left_context_chunks = 1000000 @@ -382,8 +394,7 @@ class Zipformer2(EncoderInterface): src_c = c tgt_c = c.unsqueeze(-1) - attn_mask = torch.logical_or(src_c > tgt_c, - src_c < tgt_c - left_context_chunks) + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) if __name__ == "__main__": logging.info(f"attn_mask = {attn_mask}") return attn_mask @@ -392,7 +403,7 @@ class Zipformer2(EncoderInterface): num_encoders = len(self.encoder_dim) assert len(outputs) == num_encoders output_dim = max(self.encoder_dim) - output_pieces = [ outputs[-1] ] + output_pieces = [outputs[-1]] cur_dim = self.encoder_dim[-1] for i in range(num_encoders - 2, -1, -1): d = self.encoder_dim[i] @@ -489,21 +500,38 @@ class Zipformer2(EncoderInterface): nonlin_attn_head_dim = 3 * embed_dim // 4 conv_left_pad = self.cnn_module_kernel[i] // 2 for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) - cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] return states def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), - (20000.0, ratio * x), - default=x) + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) def _balancer_schedule(min_prob: float): @@ -525,31 +553,45 @@ class Zipformer2EncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), - ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), 
(4000.0, 0.02), default=0), + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), ) -> None: super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, - straight_through_rate=0) + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) # bypass_mid is bypass used in the middle of the layer. self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) @@ -567,39 +609,39 @@ class Zipformer2EncoderLayer(nn.Module): self.const_attention_rate = copy.deepcopy(const_attention_rate) self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, pos_dim=pos_dim, num_heads=num_heads, - query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, dropout=0.0, ) - self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim) + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim) + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - self.feed_forward1 = FeedforwardModule(embed_dim, - (feedforward_dim * 3) // 4, - dropout) + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) - self.feed_forward2 = FeedforwardModule(embed_dim, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(embed_dim, - (feedforward_dim * 5) // 4, - dropout) + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) - self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4) + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) - self.conv_module1 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) - self.conv_module2 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) # TODO: remove it self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) @@ -607,15 +649,20 @@ class Zipformer2EncoderLayer(nn.Module): self.norm = BiasNorm(embed_dim) self.balancer1 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.2, max_abs=4.0, + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, 
) # balancer for output of NonlinAttentionModule self.balancer_na = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), prob=0.05, # out of concern for memory usage ) @@ -624,34 +671,50 @@ class Zipformer2EncoderLayer(nn.Module): # small. give this a very small probability, even at the start of # training, it's to fix a rare problem and it's OK to fix it slowly. self.balancer_ff2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), max_abs=2.0, prob=0.05, ) self.balancer_ff3 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), max_abs=4.0, prob=0.05, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.balancer2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.1, max_abs=4.0, + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, ) - def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: - if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): return None batch_size = x.shape[1] mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) @@ -677,21 +740,21 @@ class Zipformer2EncoderLayer(nn.Module): src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. 
+ feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. - Returns: - A tensor which has the same shape as src + Returns: + A tensor which has the same shape as src """ src_orig = src @@ -699,7 +762,9 @@ class Zipformer2EncoderLayer(nn.Module): if torch.jit.is_scripting() or torch.jit.is_tracing(): attention_skip_rate = 0.0 else: - attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) # attn_weights: (num_heads, batch_size, seq_len, seq_len) attn_weights = self.self_attn_weights( @@ -711,7 +776,9 @@ class Zipformer2EncoderLayer(nn.Module): src = src + self.feed_forward1(src) - self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) selected_attn_weights = attn_weights[0:1] if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -722,53 +789,75 @@ class Zipformer2EncoderLayer(nn.Module): # averaging-over-time operation. # only need the mask, can just use the 1st one and expand later selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) - selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) self_attn = self.self_attn1(src, attn_weights) - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): conv_skip_rate = 0.0 else: conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): ff2_skip_rate = 0.0 else: ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), - ff2_skip_rate) + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) # bypass in the middle of the layer. 
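# ---------------------------------------------------------------------------
# Illustrative sketch (not from the original patch; assumes only `torch`):
# the residual updates above use *sequence-level* dropout -- an entire
# sub-module output is dropped for a randomly chosen subset of sequences in
# the batch, rather than dropping individual elements.  The shapes and the
# 0.2 rate below are made up for the example; `sequence_dropout` itself is
# defined elsewhere in this file, and this sketch only mirrors its masking
# pattern.
import torch

def sequence_dropout_mask(x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    # x: (seq_len, batch_size, embed_dim); one keep/drop decision per sequence.
    batch_size = x.shape[1]
    return (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)

seq_len, batch_size, embed_dim = 8, 4, 16
src = torch.randn(seq_len, batch_size, embed_dim)
branch_out = torch.randn_like(src)                   # stand-in for e.g. conv_module1(src)
mask = sequence_dropout_mask(src, dropout_rate=0.2)  # (batch_size, 1)
src = src + branch_out * mask                        # mask broadcasts over time and channels
# ---------------------------------------------------------------------------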
src = self.bypass_mid(src_orig, src) self_attn = self.self_attn2(src, attn_weights) - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): conv_skip_rate = 0.0 else: conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) if torch.jit.is_scripting() or torch.jit.is_tracing(): ff3_skip_rate = 0.0 else: ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), - ff3_skip_rate) + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) src = self.balancer1(src) src = self.norm(src) @@ -912,20 +1001,22 @@ class Zipformer2Encoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, ) -> None: super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, - length_factor=1.0) + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -934,13 +1025,15 @@ class Zipformer2Encoder(nn.Module): assert 0 <= warmup_begin <= warmup_end - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin # interpreted as a training batch index for i in range(num_layers): cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0) + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) cur_begin = cur_end def forward( @@ -1014,8 +1107,13 @@ class Zipformer2Encoder(nn.Module): new_states = [] for i, mod in enumerate(self.layers): ( - cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 - ) = states[i * 6: (i + 1) * 6] + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] ( output, new_cached_key, @@ -1023,7 +1121,7 @@ class Zipformer2Encoder(nn.Module): new_cached_val1, new_cached_val2, new_cached_conv1, - new_cached_conv2 + new_cached_conv2, ) = mod.streaming_forward( output, pos_emb, @@ -1055,13 +1153,15 @@ class BypassModule(nn.Module): "straight-through", i.e. to not do the bypass operation much initially, in order to force all the modules to learn something. 
""" + def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0): + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): super().__init__() self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) self.skip_rate = copy.deepcopy(skip_rate) @@ -1077,9 +1177,9 @@ class BypassModule(nn.Module): if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: return self.bypass_scale else: - ans = limit_param_value(self.bypass_scale, - min=float(self.scale_min), - max=float(self.scale_max)) + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) skip_rate = float(self.skip_rate) if skip_rate != 0.0: mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate @@ -1088,13 +1188,14 @@ class BypassModule(nn.Module): # on which we have randomly chosen to do layer-skipping. straight_through_rate = float(self.straight_through_rate) if straight_through_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) ans = torch.maximum(ans, mask.to(ans.dtype)) return ans - def forward(self, - src_orig: Tensor, - src: Tensor): + def forward(self, src_orig: Tensor, src: Tensor): """ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) Returns: something with the same shape as src and src_orig @@ -1109,15 +1210,13 @@ class DownsampledZipformer2Encoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike): + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): super(DownsampledZipformer2Encoder, self).__init__() self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, - downsample, dropout) + self.downsample = SimpleDownsample(dim, downsample, dropout) self.num_layers = encoder.num_layers self.encoder = encoder self.upsample = SimpleUpsample(dim, downsample) @@ -1149,7 +1248,7 @@ class DownsampledZipformer2Encoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if attn_mask is not None: - attn_mask = attn_mask[::ds,::ds] + attn_mask = attn_mask[::ds, ::ds] src = self.encoder( src, @@ -1160,7 +1259,7 @@ class DownsampledZipformer2Encoder(nn.Module): ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) @@ -1196,7 +1295,7 @@ class DownsampledZipformer2Encoder(nn.Module): ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src), new_states @@ -1205,10 +1304,8 @@ class SimpleDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. 
""" - def __init__(self, - channels: int, - downsample: int, - dropout: FloatLike): + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): super(SimpleDownsample, self).__init__() self.bias = nn.Parameter(torch.zeros(downsample)) @@ -1218,8 +1315,7 @@ class SimpleDownsample(torch.nn.Module): self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -1232,7 +1328,7 @@ class SimpleDownsample(torch.nn.Module): # Pad to an exact multiple of self.downsample # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds @@ -1253,14 +1349,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.upsample = upsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -1298,11 +1392,13 @@ class CompactRelPositionalEncoding(torch.nn.Module): length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives less weight to small differences of offset near the origin. """ + def __init__( - self, embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -1326,19 +1422,22 @@ class CompactRelPositionalEncoding(torch.nn.Module): return # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T-1), T, - device=x.device).to(torch.float32).unsqueeze(1) + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution # for small time offsets but less resolution for large time offsets. - compression_length = (self.embed_dim ** 0.5) + compression_length = self.embed_dim**0.5 # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; # but it does so more slowly than T for large absolute values of T. # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which # is important. - x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) # if self.length_factor == 1.0, then length_scale is chosen so that the # FFT can exactly separate points close to the origin (T == 0). 
So this @@ -1380,7 +1479,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): - x_size_left + 1 : self.pe.size(0) // 2 # noqa E203 + x.size(0), - : + :, ] pos_emb = pos_emb.unsqueeze(0) return self.dropout(pos_emb) @@ -1407,15 +1506,14 @@ class RelPositionMultiheadAttentionWeights(nn.Module): """ def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), - (4000.0, 0.0)) + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), ) -> None: super().__init__() self.embed_dim = embed_dim @@ -1434,13 +1532,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # dividing it between the query and key. Note: this module is intended # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=query_head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) # add a balancer for the keys that runs with very small probability, and # tries to enforce that all dimensions have mean around zero. The @@ -1450,19 +1551,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # bias because the small numerical roundoff tends to have a non-random # sign. This module is intended to prevent that. Use a very small # probability; that should be suffixient to fix the problem. - self.balance_keys = Balancer(key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025) + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(pos_dim, - num_heads * pos_head_dim, - bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option self.copy_pos_query = Identity() @@ -1498,10 +1600,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_dim = query_head_dim * num_heads # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query - p = x[...,2*query_dim:] + p = x[..., 2 * query_dim :] assert p.shape[-1] == num_heads * pos_head_dim q = self.copy_query(q) # for diagnostics only, does nothing. 
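# ---------------------------------------------------------------------------
# Illustrative sketch (not from the original patch; assumes only `torch`):
# how the fused in_proj output is sliced into the attention query, the
# attention key and the positional-encoding query, mirroring the indexing a
# few lines above.  The dimensions are made up for the example, not the
# recipe's actual configuration.
import torch

num_heads, query_head_dim, pos_head_dim = 4, 32, 4
query_dim = num_heads * query_head_dim
in_proj_dim = 2 * query_dim + num_heads * pos_head_dim   # q + k + p

seq_len, batch_size = 10, 2
x = torch.randn(seq_len, batch_size, in_proj_dim)        # stands in for in_proj(src)

q = x[..., 0:query_dim]                # (seq_len, batch, num_heads * query_head_dim)
k = x[..., query_dim : 2 * query_dim]  # the key shares the query's head dimension
p = x[..., 2 * query_dim :]            # position-encoding query
assert p.shape[-1] == num_heads * pos_head_dim
# ---------------------------------------------------------------------------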
@@ -1529,7 +1631,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if use_pos_scores: pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) @@ -1548,12 +1652,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_scores = torch.gather(pos_scores, dim=1, index=indexes) pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) attn_scores = attn_scores + pos_scores @@ -1572,10 +1680,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # but we view this as a failsafe to avoid "implausible" parameter # values rather than a regularization method that should be active # under normal circumstances. - attn_scores = penalize_abs_values_gt(attn_scores, - limit=25.0, - penalty=1.0e-04, - name=self.name) + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) @@ -1588,7 +1695,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): attn_scores = attn_scores.masked_fill(attn_mask, -1000) if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape attn_scores = attn_scores.masked_fill( key_padding_mask.unsqueeze(1), -1000, @@ -1644,14 +1754,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_dim = query_head_dim * num_heads # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query - p = x[...,2*query_dim:] + p = x[..., 2 * query_dim :] assert p.shape[-1] == num_heads * pos_head_dim # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) k = torch.cat([cached_key, k], dim=0) # Update cached left contexts cached_key = k[-left_context_len:, ...] @@ -1672,13 +1785,15 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] 
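# ---------------------------------------------------------------------------
# Illustrative sketch (not from the original patch; assumes only `torch`):
# the as_strided call reformatted above (and in the streaming code below)
# turns scores indexed by *relative* offset into scores indexed by *absolute*
# key position without copying.  A minimal 2-D illustration; the real code
# carries extra head and batch dimensions.
import torch

T = 4
# scores[t, j]: query position t, relative index j (offset j - (T - 1)).
scores = torch.arange(T * (2 * T - 1)).reshape(T, 2 * T - 1)
abs_scores = scores.as_strided(
    (T, T),
    (scores.stride(0) - scores.stride(1), scores.stride(1)),
    storage_offset=scores.stride(1) * (T - 1),
)
# abs_scores[t, s] == scores[t, s - t + (T - 1)], i.e. the entry whose
# relative index corresponds to key position s as seen from query position t.
for t in range(T):
    for s in range(T):
        assert abs_scores[t, s] == scores[t, s - t + T - 1]
# ---------------------------------------------------------------------------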
pos_scores = torch.matmul(p, pos_emb) - + if torch.jit.is_tracing(): (num_heads, batch_size, time1, n) = pos_scores.shape rows = torch.arange(start=time1 - 1, end=-1, step=-1) @@ -1692,16 +1807,25 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) attn_scores = attn_scores + pos_scores - assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape @@ -1714,18 +1838,21 @@ class RelPositionMultiheadAttentionWeights(nn.Module): return attn_weights, cached_key - def _print_attn_entropy( - self, - attn_weights: Tensor): + def _print_attn_entropy(self, attn_weights: Tensor): # attn_weights: (num_heads, batch_size, seq_len, seq_len) (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).mean(dim=(1,2)) - logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) class SelfAttention(nn.Module): @@ -1738,25 +1865,26 @@ class SelfAttention(nn.Module): num_heads: the number of attention heads value_head_dim: the value dimension per head """ + def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, ) -> None: super().__init__() - self.in_proj = nn.Linear(embed_dim, - num_heads * value_head_dim, - bias=True) + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) - self.out_proj = ScaledLinear(num_heads * value_head_dim, - embed_dim, bias=True, - initial_scale=0.05) + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) def forward( self, @@ -1785,8 +1913,11 @@ class SelfAttention(nn.Module): x = torch.matmul(attn_weights, x) # v: (num_heads, batch_size, seq_len, value_head_dim) - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) # returned value is of shape (seq_len, 
batch_size, embed_dim), like the input. x = self.out_proj(x) @@ -1823,7 +1954,10 @@ class SelfAttention(nn.Module): x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) x = torch.cat([cached_val, x], dim=0) # Update cached left contexts cached_val = x[-left_context_len:, ...] @@ -1836,8 +1970,11 @@ class SelfAttention(nn.Module): x = torch.matmul(attn_weights, x) # v: (num_heads, batch_size, seq_len, value_head_dim) - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. x = self.out_proj(x) @@ -1846,33 +1983,38 @@ class SelfAttention(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model. - """ - def __init__(self, - embed_dim: int, - feedforward_dim: int, - dropout: FloatLike): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(embed_dim, feedforward_dim) - self.hidden_balancer = Balancer(feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0) + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, - activation='SwooshL', - dropout_p=dropout, - dropout_shared_dim=0, bias=True, - initial_scale=0.1) + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) - self.out_whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) def forward(self, x: Tensor): x = self.in_proj(x) @@ -1893,9 +2035,9 @@ class NonlinAttention(nn.Module): """ def __init__( - self, - channels: int, - hidden_channels: int, + self, + channels: int, + hidden_channels: int, ) -> None: super().__init__() @@ -1908,7 +2050,8 @@ class NonlinAttention(nn.Module): # starting from about 3, and poorly-trained instances of the module have smaller abs values # before the sigmoid. self.balancer = Balancer( - hidden_channels, channel_dim=-1, + hidden_channels, + channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), min_abs=0.5, @@ -1920,19 +2063,23 @@ class NonlinAttention(nn.Module): self.identity2 = Identity() # for diagnostics. self.identity3 = Identity() # for diagnostics. 
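# ---------------------------------------------------------------------------
# Illustrative sketch (not from the original patch): many of the arguments
# above -- skip rates, balancer limits -- are ScheduledFloat values written
# as (batch_count, value) breakpoints.  The helper below is an assumed,
# standalone re-creation of that idea as piecewise-linear interpolation over
# the batch count; it is NOT the ScheduledFloat class used by this recipe.
def piecewise_linear(points, batch_count: float) -> float:
    # points: sorted [(batch_count, value), ...] breakpoints.
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)

# E.g. an attention skip rate with breakpoints (0.0, 0.2), (4000.0, 0.05),
# (16000, 0.0) decays from 0.2 at the start of training to 0.0 after 16k batches:
sched = [(0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0)]
assert piecewise_linear(sched, 0) == 0.2
assert abs(piecewise_linear(sched, 2000) - 0.125) < 1e-9
assert piecewise_linear(sched, 20000) == 0.0
# ---------------------------------------------------------------------------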
- self.out_proj = ScaledLinear(hidden_channels, channels, - bias=True, - initial_scale=0.05) + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) - self.whiten1 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) - self.whiten2 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) def forward( self, @@ -1940,11 +2087,11 @@ class NonlinAttention(nn.Module): attn_weights: Tensor, ) -> Tensor: """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) -attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x """ x = self.in_proj(x) @@ -2014,13 +2161,21 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) (seq_len, batch_size, embed_dim) = x.shape num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) # now x: (num_heads, batch_size, seq_len, head_dim) # Pad cached tensor - assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) x_pad = torch.cat([cached_x, x], dim=2) # Update cached tensor cached_x = x_pad[:, :, -left_context_len:, :] @@ -2045,8 +2200,12 @@ class ConvolutionModule(nn.Module): bias (bool): Whether to use bias in conv layers (default=True). """ + def __init__( - self, channels: int, kernel_size: int, causal: bool, + self, + channels: int, + kernel_size: int, + causal: bool, ) -> None: """Construct a ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -2057,7 +2216,8 @@ class ConvolutionModule(nn.Module): self.causal = causal self.in_proj = nn.Linear( - channels, 2 * bottleneck_dim, + channels, + 2 * bottleneck_dim, ) # the gradients on in_proj are a little noisy, likely to do with the # sigmoid in glu. @@ -2076,7 +2236,8 @@ class ConvolutionModule(nn.Module): # it will be in a better position to start learning something, i.e. to latch onto # the correct range. 
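# ---------------------------------------------------------------------------
# Illustrative sketch (not from the original patch; assumes only `torch`):
# in_proj above produces 2 * bottleneck_dim channels because the module gates
# one half with a sigmoid of the other half (the "sigmoid in glu" that the
# comment refers to).  The module's own forward (not shown in this hunk)
# wraps that gating with balancers; the lines below only illustrate glu-style
# gating itself.
import torch
import torch.nn.functional as F

seq_len, batch_size, bottleneck_dim = 8, 2, 16
x = torch.randn(seq_len, batch_size, 2 * bottleneck_dim)   # stands in for in_proj(src)
a, b = x.chunk(2, dim=-1)
gated = a * torch.sigmoid(b)                     # glu written out by hand
assert torch.allclose(gated, F.glu(x, dim=-1))   # same thing via the built-in
# ---------------------------------------------------------------------------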
self.balancer1 = Balancer( - bottleneck_dim, channel_dim=-1, + bottleneck_dim, + channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), max_positive=1.0, min_abs=1.5, @@ -2091,31 +2252,40 @@ class ConvolutionModule(nn.Module): assert kernel_size % 2 == 1 - self.depthwise_conv = ChunkCausalDepthwiseConv1d( - channels=bottleneck_dim, - kernel_size=kernel_size) if causal else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2) + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) self.balancer2 = Balancer( - bottleneck_dim, channel_dim=1, + bottleneck_dim, + channel_dim=1, min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), max_positive=1.0, min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), max_abs=10.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, channels, activation='SwooshR', - dropout_p=0.0, initial_scale=0.05, + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, ) def forward( @@ -2153,9 +2323,15 @@ class ConvolutionModule(nn.Module): if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0: + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): # Not support exporting a model for simulated streaming decoding - assert self.causal, "Must initialize model with causal=True if you use chunk_size" + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" x = self.depthwise_conv(x, chunk_size=chunk_size) else: x = self.depthwise_conv(x) @@ -2225,10 +2401,12 @@ def _test_zipformer_main(causal: bool = False): # Just make sure the forward pass runs. c = Zipformer2( - encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), causal=causal, chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,) + left_context_frames=(64,), ) batch_size = 5 seq_len = 20 diff --git a/egs/multi_zh-hans/ASR/README.md b/egs/multi_zh-hans/ASR/README.md new file mode 100644 index 000000000..537816a5d --- /dev/null +++ b/egs/multi_zh-hans/ASR/README.md @@ -0,0 +1,39 @@ + +# Introduction + +This recipe includes scripts for training Zipformer model using multiple Chinese datasets. + +# Included Training Sets +1. THCHS-30 +2. AiShell-{1,2,4} +3. ST-CMDS +4. Primewords +5. MagicData +6. Aidatatang_200zh +7. AliMeeting +8. WeNetSpeech +9. 
KeSpeech-ASR + +|Datset| Number of hours| URL| +|---|---:|---| +|**TOTAL**|14,106|---| +|THCHS-30|35|https://www.openslr.org/18/| +|AiShell-1|170|https://www.openslr.org/33/| +|AiShell-2|1,000|http://www.aishelltech.com/aishell_2| +|AiShell-4|120|https://www.openslr.org/111/| +|ST-CMDS|110|https://www.openslr.org/38/| +|Primewords|99|https://www.openslr.org/47/| +|aidatatang_200zh|200|https://www.openslr.org/62/| +|MagicData|755|https://www.openslr.org/68/| +|AliMeeting|100|https://openslr.org/119/| +|WeNetSpeech|10,000|https://github.com/wenet-e2e/WenetSpeech| +|KeSpeech|1,542|https://github.com/KeSpeech/KeSpeech| + + +# Included Test Sets +1. Aishell-{1,2,4} +2. Aidatatang_200zh +3. AliMeeting +4. MagicData +5. KeSpeech-ASR +6. WeNetSpeech \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md new file mode 100644 index 000000000..31fbd9700 --- /dev/null +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -0,0 +1,38 @@ +## Results + +### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model + +This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. + +#### Non-streaming + +Best results (num of params : ~69M): + +The training command: + +``` +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --use-fp16 1 \ + --max-duration 600 \ + --num-workers 8 +``` + +The decoding command: + +``` +./zipformer/decode.py \ + --epoch 20 \ + --avg 1 +``` + +Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled). + +| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| +| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | + + +The pre-trained model is available here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 diff --git a/egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py b/egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py new file mode 100755 index 000000000..d078e5b98 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/bpe_model_to_tokens.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +""" +This script takes `bpe.model` as input and generates a file `tokens.txt` +from it. 
+ +Usage: +./bpe_model_to_tokens.py /path/to/input/bpe.model > tokens.txt +""" +import argparse + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "bpe_model", + type=str, + help="Path to the input bpe.model", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + for i in range(sp.vocab_size()): + print(sp.id_to_piece(i), i) + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/local/compile_lg.py b/egs/multi_zh-hans/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py new file mode 100755 index 000000000..2581ee42f --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# Copyright 2023 Xiaomi Corp. (Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import torch +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_kespeech_dev_test(): + in_out_dir = Path("data/fbank/kespeech") + # number of workers in dataloader + num_workers = 42 + + # number of seconds in a batch + batch_duration = 600 + + subsets = ( + "dev_phase1", + "dev_phase2", + "test", + ) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + for partition in subsets: + cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = in_out_dir / f"kespeech-asr_cuts_{partition}_raw.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{in_out_dir}/feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_kespeech_dev_test() + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py new file mode 100755 index 000000000..8bfbc7b50 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# Copyright 2023 Xiaomi Corp. (Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + set_audio_duration_mismatch_tolerance, + set_caching_enabled, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--training-subset", + type=str, + default="train_phase1", + choices=["train_phase1", "train_phase2"], + help="The training subset for computing fbank feature.", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--num-splits", + type=int, + required=True, + help="The number of splits of the given subset", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + return parser + + +def compute_fbank_kespeech_splits(args): + subset = args.training_subset + subset = str(subset) + num_splits = args.num_splits + output_dir = f"data/fbank/kespeech/{subset}_split_{num_splits}" + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + num_digits = len(str(num_splits)) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_caching_enabled(False) + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"kespeech-asr_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = output_dir / f"kespeech-asr_cuts_{subset}_raw.{idx}.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +def main(): + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + + log_filename = "log-compute_fbank_kespeech_splits" + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + log_filename = f"{log_filename}-{date_time}" + + logging.basicConfig( + filename=log_filename, + format=formatter, + level=logging.INFO, + filemode="w", + ) + + console = logging.StreamHandler() + console.setLevel(logging.INFO) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + compute_fbank_kespeech_splits(args) + + +if __name__ == "__main__": + main() diff --git 
a/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py new file mode 100755 index 000000000..5649d3815 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_magicdata.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the MagicData dataset. +It looks for manifests in the directory data/manifests/magicdata. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False): + src_dir = Path("data/manifests/magicdata") + output_dir = Path("data/fbank") + num_jobs = min(30, os.cpu_count()) + + dataset_parts = ("train", "test", "dev") + prefix = "magicdata" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. 
+ for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition and speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_magicdata( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py new file mode 100755 index 000000000..303a16580 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_primewords.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the Primewords dataset. +It looks for manifests in the directory data/manifests/primewords. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False): + src_dir = Path("data/manifests/primewords") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + + dataset_parts = ("train",) + prefix = "primewords" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition and speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_primewords( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py new file mode 100755 index 000000000..730806954 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_stcmds.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the ST-CMDS dataset. +It looks for manifests in the directory data/manifests/stcmds. + +The generated fbank features are saved in data/fbank. 
+""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False): + src_dir = Path("data/manifests/stcmds") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + + dataset_parts = ("train",) + prefix = "stcmds" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition and speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_stcmds( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py new file mode 100755 index 000000000..58bb8002a --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_thchs30.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the THCHS-30 dataset. +It looks for manifests in the directory data/manifests/thchs30. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False): + src_dir = Path("data/manifests/thchs30") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + + dataset_parts = ( + "train", + "dev", + "test", + ) + prefix = "thchs_30" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition: + cut_set = ( + (cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)) + if speed_perturb + else cut_set + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + compute_fbank_thchs30( + num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb + ) diff --git a/egs/multi_zh-hans/ASR/local/prepare_char.py b/egs/multi_zh-hans/ASR/local/prepare_char.py new file mode 120000 index 000000000..be7da61af --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py new file mode 100755 index 000000000..020800c15 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/prepare_for_bpe_model.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script tokenizes the training transcript by CJK characters +# and saves the result to transcript_chars.txt, which is used +# to train the BPE model later. + +import argparse +from pathlib import Path + +from tqdm.auto import tqdm + +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Output directory. + The generated transcript_chars.txt is saved to this directory. + """, + ) + + parser.add_argument( + "--text", + type=str, + help="WenetSpeech training transcript.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + text = Path(args.text) + + assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" 
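+
+    # The loop below re-tokenizes every transcript line with
+    # icefall.utils.tokenize_by_CJK_char, so that each CJK character becomes
+    # its own space-separated token while non-CJK text is left as-is; the
+    # resulting transcript_chars.txt is what the BPE trainer consumes.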
+ + transcript_path = lang_dir / "transcript_chars.txt" + + with open(text, "r", encoding="utf-8") as fin: + with open(transcript_path, "w+", encoding="utf-8") as fout: + for line in fin: + fout.write(tokenize_by_CJK_char(line) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/local/prepare_lang.py b/egs/multi_zh-hans/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py b/egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py b/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py new file mode 100755 index 000000000..20274263f --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/preprocess_kespeech.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# Copyright 2023 Xiaomi Corp. (Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import re +from pathlib import Path + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall import setup_logger + +# Similar text filtering and normalization procedure as in: +# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh + + +def normalize_text( + utt: str, + punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), + whitespace_pattern=re.compile(r"\s\s+"), +) -> str: + return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) + + +def has_no_oov( + sup: SupervisionSegment, + oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER|SPOKEN_NOISE)>"), +) -> bool: + return oov_pattern.search(sup.text) is None + + +def preprocess_kespeech(speed_perturb: bool = False): + src_dir = Path("data/manifests/kespeech") + output_dir = Path("data/fbank/kespeech") + output_dir.mkdir(exist_ok=True) + + # Note: By default, we preprocess all sub-parts. + # You can delete those that you don't need. 
+ # For instance, if you don't want to use the test subpart, just remove + # the line below containing "test" + dataset_parts = ( + "dev_phase1", + "dev_phase2", + "test", + "train_phase1", + "train_phase2", + ) + + logging.info("Loading manifest (may take 10 minutes)") + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix="jsonl.gz", + prefix="kespeech-asr", + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + logging_threshold = 50 + logging_count = 0 + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"kespeech-asr_cuts_{partition}_raw.jsonl.gz" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + # Note this step makes the recipe different than LibriSpeech: + # We must filter out some utterances and remove punctuation + # to be consistent with Kaldi. + logging.info("Filtering OOV utterances from supervisions") + m["supervisions"] = m["supervisions"].filter(has_no_oov) + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + orig_text = sup.text + sup.text = normalize_text(sup.text) + if logging_count < logging_threshold and len(orig_text) != len(sup.text): + logging_count += 1 + logging.info( + f"\nOriginal text vs normalized text:\n{orig_text}\n{sup.text}" + ) + + # Create long-recording cut manifests. + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + # Run data augmentation that needs to be done in the + # time domain. + if partition not in [ + "dev_phase1", + "dev_phase2", + "test", + ]: + if speed_perturb: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 8 minutes and saving may take 20 minutes)" + ) + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--speed-perturb", + type=bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + +def main(): + setup_logger(log_filename="./log-preprocess-kespeech") + + args = get_args() + preprocess_kespeech(speed_perturb=args.speed_perturb) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/local/text2token.py b/egs/multi_zh-hans/ASR/local/text2token.py new file mode 120000 index 000000000..ce5cfd537 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/text2token.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/train_bpe_model.py b/egs/multi_zh-hans/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..976ea0ba8 --- /dev/null +++ b/egs/multi_zh-hans/ASR/local/train_bpe_model.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+# pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    parser.add_argument(
+        "--byte-fallback",
+        type=bool,
+        default=True,
+        help="Enable byte fallback for BPE model.",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+
+    model_type = "unigram"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    train_text = args.transcript
+    character_coverage = 0.98
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
+
+    model_file = Path(model_prefix + ".model")
+    if not model_file.is_file():
+        spm.SentencePieceTrainer.train(
+            input=train_text,
+            vocab_size=vocab_size,
+            model_type=model_type,
+            model_prefix=model_prefix,
+            input_sentence_size=input_sentence_size,
+            character_coverage=character_coverage,
+            user_defined_symbols=user_defined_symbols,
+            unk_id=unk_id,
+            bos_id=-1,
+            eos_id=-1,
+            byte_fallback=args.byte_fallback,
+        )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh
new file mode 100755
index 000000000..5d0fe66a4
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/prepare.sh
@@ -0,0 +1,373 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+num_splits=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+vocab_sizes=(
+  2000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
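+#
+# Individual stages can be run on their own.  For example, using the stage
+# numbers defined below, the following prepares only ST-CMDS (stage 6) and
+# Primewords (stage 7):
+#
+#   ./prepare.sh --stage 6 --stop-stage 7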
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+log "Dataset: musan"
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Soft link fbank of musan"
+  mkdir -p data/fbank
+  if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then
+    cd data/fbank
+    ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) .
+    ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
+    cd ../..
+  else
+    log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4"
+    exit 1
+  fi
+fi
+
+log "Dataset: THCHS-30"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare THCHS-30"
+  if [ ! -d $dl_dir/thchs30 ]; then
+    log "Downloading THCHS-30"
+    lhotse download thchs30 $dl_dir/thchs30
+  fi
+
+  if [ ! -f data/manifests/.thchs30.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare thchs-30 $dl_dir/thchs30 data/manifests/thchs30
+    touch data/manifests/.thchs30.done
+  fi
+
+  if [ ! -f data/fbank/.thchs30.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_thchs30.py
+    touch data/fbank/.thchs30.done
+  fi
+fi
+
+log "Dataset: AISHELL-1"
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare AISHELL-1"
+  if [ -e ../../aishell/ASR/data/fbank/.aishell.done ]; then
+    cd data/fbank
+    ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_feats_train) .
+    ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_feats_dev) .
+    ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_feats_test) .
+    ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_cuts_train.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell/ASR/data/fbank/aishell_cuts_test.jsonl.gz) .
+    cd ../..
+  else
+    log "Abort! Please run ../../aishell/ASR/prepare.sh --stage 3 --stop-stage 3"
+    exit 1
+  fi
+fi
+
+log "Dataset: AISHELL-2"
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Prepare AISHELL-2"
+  if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then
+    cd data/fbank
+    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_train) .
+    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_dev) .
+    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats_test) .
+    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_train.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) .
+    cd ../..
+  else
+    log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
+    exit 1
+  fi
+fi
+
+log "Dataset: AISHELL-4"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare AISHELL-4"
+  if [ -e ../../aishell4/ASR/data/fbank/.aishell4.done ]; then
+    cd data/fbank
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
+    cd ../..
+  else
+    log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3"
+    exit 1
+  fi
+fi
+
+log "Dataset: ST-CMDS"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare ST-CMDS"
+  if [ ! -f $dl_dir/stcmds/ST-CMDS-20170001_1-OS.tar.gz ]; then
+    log "Downloading ST-CMDS"
+    lhotse download stcmds $dl_dir/stcmds
+  fi
+
+  if [ ! -f data/manifests/.stcmds.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare stcmds $dl_dir/stcmds data/manifests/stcmds
+    touch data/manifests/.stcmds.done
+  fi
+
+  if [ ! -f data/fbank/.stcmds.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_stcmds.py
+    touch data/fbank/.stcmds.done
+  fi
+fi
+
+
+log "Dataset: Primewords"
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare Primewords"
+  if [ ! -f $dl_dir/primewords/primewords_md_2018_set1.tar.gz ]; then
+    log "Downloading Primewords"
+    lhotse download primewords $dl_dir/primewords
+  fi
+
+  if [ ! -f data/manifests/.primewords.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare primewords $dl_dir/primewords data/manifests/primewords
+    touch data/manifests/.primewords.done
+  fi
+
+  if [ ! -f data/fbank/.primewords.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_primewords.py
+    touch data/fbank/.primewords.done
+  fi
+fi
+
+log "Dataset: MagicData"
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare MagicData"
+  if [ ! -f $dl_dir/magicdata/train_set.tar.gz ]; then
+    log "Downloading MagicData"
+    lhotse download magicdata $dl_dir/magicdata
+  fi
+
+  if [ ! -f data/manifests/.magicdata.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare magicdata $dl_dir/magicdata data/manifests/magicdata
+    touch data/manifests/.magicdata.done
+  fi
+
+  if [ ! -f data/fbank/.magicdata.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_magicdata.py
+    touch data/fbank/.magicdata.done
+  fi
+fi
+
+log "Dataset: aidatatang_200zh"
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Prepare aidatatang_200zh"
+  if [ -e ../../aidatatang_200zh/ASR/data/fbank/.aidatatang_200zh.done ]; then
+    cd data/fbank
+    ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_feats_train) .
+    ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_feats_dev) .
+    ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_feats_test) .
+    ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_cuts_train.jsonl.gz) .
+    ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../../../../aidatatang_200zh/ASR/data/fbank/aidatatang_cuts_test.jsonl.gz) .
+    cd ../..
+  else
+    log "Abort! Please run ../../aidatatang_200zh/ASR/prepare.sh --stage 4 --stop-stage 4"
+    exit 1
+  fi
+fi
+
+log "Dataset: Ali-Meeting"
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Prepare Ali-Meeting"
+  if [ -e ../../alimeeting/ASR/data/fbank/.fbank.done ]; then
+    cd data/fbank
+    ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_feats_train) .
+    ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_feats_eval) .
+    ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_feats_test) .
+ ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_cuts_train.jsonl.gz) . + ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_cuts_eval.jsonl.gz) . + ln -svf $(realpath ../../../../alimeeting/ASR/data/fbank/alimeeting-far_cuts_test.jsonl.gz) . + cd ../.. + else + log "Abort! Please run ../../alimeeting/ASR/prepare.sh --stage 5 --stop-stage 5" + exit 1 + fi +fi + +log "Dataset: WenetSpeech" +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Prepare WenetSpeech" + if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then + cd data/fbank + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . + + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_1000) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech + cd ../.. + else + log "Abort! Please run ../../wenetspeech/ASR/prepare.sh" + exit 1 + fi + + if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then + cd data + cp -r ../../../../wenetspeech/ASR/data/lang_char . + cd .. + else + log "Abort! Please run ../../wenetspeech/ASR/prepare.sh" + exit 1 + fi +fi + +log "Dataset: KeSpeech" +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Prepare KeSpeech" + if [ ! -d $dl_dir/KeSpeech ]; then + log "Abort! Please download KeSpeech first." + log "KeSpeech download link: https://github.com/KeSpeech/KeSpeech" + exit 1 + fi + + if [ ! -f data/manifests/.kespeech.done ]; then + mkdir -p data/manifests + lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech + touch data/manifests/.kespeech.done + fi + + if [ ! -f data/fbank/.kespeech.done ]; then + mkdir -p data/fbank + + log "Preprocess KeSpeech manifest" + if [ ! 
-f data/fbank/.kespeech_preprocess_complete ]; then
+      python3 ./local/preprocess_kespeech.py
+      touch data/fbank/.kespeech_preprocess_complete
+    fi
+
+    if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
+      log "Splitting KeSpeech train_phase1"
+      lhotse split ${num_splits} \
+        data/fbank/kespeech/kespeech-asr_cuts_train_phase1_raw.jsonl.gz \
+        data/fbank/kespeech/train_phase1_split_${num_splits}
+      touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
+    fi
+
+    if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
+      log "Splitting KeSpeech train_phase2"
+      lhotse split ${num_splits} \
+        data/fbank/kespeech/kespeech-asr_cuts_train_phase2_raw.jsonl.gz \
+        data/fbank/kespeech/train_phase2_split_${num_splits}
+      touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
+    fi
+
+    log "Compute KeSpeech fbank for train_phase1"
+    ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1
+
+    log "Compute KeSpeech fbank for train_phase2"
+    ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase2
+
+    log "Compute KeSpeech fbank for test/dev"
+    ./local/compute_fbank_kespeech_dev_test.py
+
+    touch data/fbank/.kespeech.done
+  fi
+fi
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: BPE model training (note that we use transcripts of wenetspeech only for BPE training)"
+  ./local/prepare_for_bpe_model.py --lang-dir ./data/lang_char --text ./data/lang_char/text
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    mkdir -p $lang_dir
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --transcript ./data/lang_char/transcript_chars.txt \
+        --vocab-size $vocab_size
+
+      ./local/bpe_model_to_tokens.py $lang_dir/bpe.model > $lang_dir/tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      cp data/lang_char/words.txt $lang_dir
+
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+
+    if [ ! -f $lang_dir/L.fst ]; then
+      log "Converting L.pt to L.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L.pt \
+        $lang_dir/L.fst
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.fst ]; then
+      log "Converting L_disambig.pt to L_disambig.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L_disambig.pt \
+        $lang_dir/L_disambig.fst
+    fi
+  done
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)"
+
+  if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then
+    cd data
+    ln -s ../../../../wenetspeech/ASR/data/lm .
+    cd ..
+  else
+    log "Abort!
Please run ../../wenetspeech/ASR/prepare.sh" + exit 1 + fi +fi + +if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then + log "Stage 15: Compile LG" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + python ./local/compile_lg.py --lang-dir $lang_dir + done +fi + + diff --git a/egs/multi_zh-hans/ASR/shared b/egs/multi_zh-hans/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/multi_zh-hans/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..b1b7bff93 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,388 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
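+
+    A rough usage sketch (illustrative only; in this recipe the train/valid
+    cut sets typically come from MultiDataset in multi_dataset.py):
+
+        parser = argparse.ArgumentParser()
+        AsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+
+        data_module = AsrDataModule(args)
+        train_dl = data_module.train_dataloaders(train_cuts)
+        valid_dl = data_module.valid_dataloaders(valid_cuts)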
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=300.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl diff --git a/egs/multi_zh-hans/ASR/zipformer/beam_search.py b/egs/multi_zh-hans/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py new file mode 100755 index 000000000..f501c3c30 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
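+        # Pad the tail of the features with LOG_EPS frames and extend
+        # feature_lens accordingly, so that the causal (streaming) encoder
+        # has some trailing context when decoding the final frames.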
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + 
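+        # The assembled key is used by save_results() to name the output
+        # files, so results from different fast_beam_search settings do not
+        # overwrite one another.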
+ return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list(str(text).replace(" ", "")) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + hyp_text = "".join(hyp_words) + this_batch.append((cut_id, ref_text, hyp_text)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
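+ # (Setting return_cuts=True below makes the test dataloader attach the
+ # originating lhotse Cut to batch["supervisions"]["cut"], which
+ # decode_dataset() reads to recover each utterance's cut id.)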
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/decoder.py b/egs/multi_zh-hans/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py b/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export.py b/egs/multi_zh-hans/ASR/zipformer/export.py new file mode 100755 index 000000000..723288191 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. 
+ +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_2000/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_2000/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_2000/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +- non-streaming model: +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +import re +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_2000/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. 
For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, 
inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py new file mode 100755 index 000000000..68111fad7 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +./zipformer/generate_averaged_model.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(2) use the checkpoint exp_dir/checkpoint-iter.pt +./zipformer/generate_averaged_model.py \ + --iter 22000 \ + --avg 5 \ + --exp-dir ./zipformer/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + symbol_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = symbol_table[""] + params.unk_id = symbol_table[""] + params.vocab_size = len(symbol_table) + + print("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ 
b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/joiner.py b/egs/multi_zh-hans/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/model.py b/egs/multi_zh-hans/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py new file mode 100644 index 000000000..b1920e62e --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -0,0 +1,316 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import glob +import logging +import re +from pathlib import Path +from typing import Dict, List + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, fbank_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aidatatang_cuts_train.jsonl.gz + - aishell_cuts_train.jsonl.gz + - aishell2_cuts_train.jsonl.gz + - aishell4_cuts_train_L.jsonl.gz + - aishell4_cuts_train_M.jsonl.gz + - aishell4_cuts_train_S.jsonl.gz + - alimeeting-far_cuts_train.jsonl.gz + - magicdata_cuts_train.jsonl.gz + - primewords_cuts_train.jsonl.gz + - stcmds_cuts_train.jsonl.gz + - thchs_30_cuts_train.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz + - wenetspeech/cuts_L.jsonl.gz + """ + self.fbank_dir = Path(fbank_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # THCHS-30 + logging.info("Loading THCHS-30 in lazy mode") + thchs_30_cuts = load_manifest_lazy( + self.fbank_dir / "thchs_30_cuts_train.jsonl.gz" + ) + + # AISHELL-1 + logging.info("Loading Aishell-1 in lazy mode") + aishell_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_train.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 in lazy mode") + aishell_4_L_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz" + ) + aishell_4_M_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz" + ) + aishell_4_S_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz" + ) + + # ST-CMDS + logging.info("Loading ST-CMDS in lazy mode") + stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz") + + # Primewords + logging.info("Loading Primewords in lazy mode") + primewords_cuts = load_manifest_lazy( + self.fbank_dir / "primewords_cuts_train.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData in lazy mode") + magicdata_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_train.jsonl.gz" + ) + + # Aidatatang_200zh + logging.info("Loading Aidatatang_200zh in lazy mode") + aidatatang_200zh_cuts = load_manifest_lazy( + self.fbank_dir / "aidatatang_cuts_train.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting in lazy mode") + alimeeting_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech in lazy mode") + wenetspeech_L_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech in lazy mode") + kespeech_1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz" + ) + kespeech_2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz" + ) + + return CutSet.mux( + thchs_30_cuts, + aishell_cuts, + aishell_2_cuts, + aishell_4_L_cuts, + aishell_4_M_cuts, + aishell_4_S_cuts, + stcmds_cuts, + primewords_cuts, + magicdata_cuts, + aidatatang_200zh_cuts, + alimeeting_cuts, + wenetspeech_L_cuts, + kespeech_1_cuts, + kespeech_2_cuts, + weights=[ + len(thchs_30_cuts), + len(aishell_cuts), + len(aishell_2_cuts), + len(aishell_4_L_cuts), + len(aishell_4_M_cuts), + len(aishell_4_S_cuts), + len(stcmds_cuts), + len(primewords_cuts), + 
len(magicdata_cuts), + len(aidatatang_200zh_cuts), + len(alimeeting_cuts), + len(wenetspeech_L_cuts), + len(kespeech_1_cuts), + len(kespeech_2_cuts), + ], + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + # Aidatatang_200zh + logging.info("Loading Aidatatang_200zh DEV set in lazy mode") + aidatatang_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz" + ) + + # AISHELL + logging.info("Loading Aishell DEV set in lazy mode") + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 DEV set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting DEV set in lazy mode") + alimeeting_dev_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData DEV set in lazy mode") + magicdata_dev_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech DEV set in lazy mode") + kespeech_dev_phase1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" + ) + kespeech_dev_phase2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech DEV set in lazy mode") + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" + ) + + return wenetspeech_dev_cuts + # return [ + # aidatatang_dev_cuts, + # aishell_dev_cuts, + # aishell2_dev_cuts, + # alimeeting_dev_cuts, + # magicdata_dev_cuts, + # kespeech_dev_phase1_cuts, + # kespeech_dev_phase2_cuts, + # wenetspeech_dev_cuts, + # ] + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # Aidatatang_200zh + logging.info("Loading Aidatatang_200zh set in lazy mode") + aidatatang_test_cuts = load_manifest_lazy( + self.fbank_dir / "aidatatang_cuts_test.jsonl.gz" + ) + aidatatang_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz" + ) + + # AISHELL + logging.info("Loading Aishell set in lazy mode") + aishell_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_test.jsonl.gz" + ) + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) + + # AISHELL-2 + logging.info("Loading Aishell-2 set in lazy mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # AISHELL-4 + logging.info("Loading Aishell-4 TEST set in lazy mode") + aishell4_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_test.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting set in lazy mode") + alimeeting_test_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz" + ) + alimeeting_eval_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData set in lazy mode") + magicdata_test_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_test.jsonl.gz" + ) + magicdata_dev_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech set in lazy mode") + 
kespeech_test_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz" + ) + kespeech_dev_phase1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" + ) + kespeech_dev_phase2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech set in lazy mode") + wenetspeech_test_meeting_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz" + ) + wenetspeech_test_net_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" + ) + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" + ) + + return { + "aidatatang_test": aidatatang_test_cuts, + "aidatatang_dev": aidatatang_dev_cuts, + "alimeeting_test": alimeeting_test_cuts, + "alimeeting_eval": alimeeting_eval_cuts, + "aishell_test": aishell_test_cuts, + "aishell_dev": aishell_dev_cuts, + "aishell-2_test": aishell2_test_cuts, + "aishell-2_dev": aishell2_dev_cuts, + "aishell-4": aishell4_test_cuts, + "magicdata_test": magicdata_test_cuts, + "magicdata_dev": magicdata_dev_cuts, + "kespeech-asr_test": kespeech_test_cuts, + "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts, + "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts, + "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, + "wenetspeech-net_test": wenetspeech_test_net_cuts, + "wenetspeech_dev": wenetspeech_dev_cuts, + } diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_check.py b/egs/multi_zh-hans/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/optim.py b/egs/multi_zh-hans/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py new file mode 100755 index 000000000..69ff382da --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. 
+ +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + 
logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling.py b/egs/multi_zh-hans/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py b/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py new file mode 120000 index 000000000..13fd02a78 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/subsampling.py b/egs/multi_zh-hans/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py new file mode 100755 index 000000000..4f2d728be --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -0,0 +1,1385 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from multi_dataset import MultiDataset +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
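+ # Worked example (hypothetical values): with batch_idx_train=1000,
+ # max_duration=1000 (seconds per batch), world_size=4 and ref_duration=600,
+ # this evaluates to 1000 * (1000 * 4) / 600 ≈ 6666.7 "reference" batches.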
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + 
vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if not params.use_transducer:
+        params.ctc_loss_scale = 1.0
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    data_module = AsrDataModule(args)
+    multi_dataset = MultiDataset(args.manifest_dir)
+
+    train_cuts = multi_dataset.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = data_module.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = multi_dataset.dev_cuts() + valid_dl = data_module.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/zipformer.py b/egs/multi_zh-hans/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 7cc2dae9409c76e54ef32b31fe647c5b30409cea Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 13 Sep 2023 12:39:49 +0800 Subject: [PATCH 033/113] Fixes to incorporate with the latest Lhotse release (#1249) --- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 6 +++--- egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 6 +++--- egs/ami/SURT/dprnn_zipformer/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless7/asr_datamodule.py | 6 +++--- egs/csj/ASR/local/utils/asr_datamodule.py | 6 +++--- egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 6 +++--- egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py | 2 +- .../ASR/pruned2_knowledge/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless7/gigaspeech.py | 6 +++--- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 6 +++--- egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 6 +++--- egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 6 +++--- .../ASR/transducer_stateless/asr_datamodule.py | 8 +++----- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 +++++----- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 6 +++--- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 6 +++--- egs/yesno/ASR/tdnn/asr_datamodule.py | 6 +++--- requirements-ci.txt | 1 + test/test_ali.py | 4 ++-- 24 files changed, 67 insertions(+), 68 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 167d5e15e..49a697bfd 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ 
b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -37,7 +37,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -291,8 +291,8 @@ class Aidatatang_200zhAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index efb32336a..180930747 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -278,8 +278,8 @@ class AishellAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py index 0f383a244..af37cc175 100644 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -299,8 +299,8 @@ class AiShell2AsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index d980a857f..da9da371e 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples @@ -310,8 +310,8 @@ class Aishell4AsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index a9a4675a9..4799da19d 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ 
b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -37,7 +37,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -292,8 +292,8 @@ class AlimeetingAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py index ec8106bc3..3dd786d33 100644 --- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py @@ -257,7 +257,7 @@ class AmiAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") + logging.info("Using SimpleCutSampler.") train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index 2c37244a4..73f2f1dce 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -311,8 +311,8 @@ class CommonVoiceAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py index 619820a75..272486227 100644 --- a/egs/csj/ASR/local/utils/asr_datamodule.py +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -339,8 +339,8 @@ class CSJAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index 9437c935c..9d6e3c42a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -27,7 +27,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -264,8 +264,8 @@ class GigaSpeechAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + 
logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 4d5d2b8f9..29e72b408 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py index 51df91598..a72df89e0 100644 --- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py @@ -259,7 +259,7 @@ class LibriCssAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") + logging.info("Using SimpleCutSampler.") train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index b839a4a4c..f8f558ce1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( CutMix, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -290,8 +290,8 @@ class LibriSpeechAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py index 5c01d7190..75e153cb0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index c47964b07..20df469da 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for 
PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -314,8 +314,8 @@ class LibriSpeechAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py index 8242e986d..442ff85c2 100644 --- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py @@ -17,7 +17,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -270,8 +270,8 @@ class MGB2AsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index b1b7bff93..3d58ebf3a 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -300,8 +300,8 @@ class AsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 2240c1c1d..39beffdcf 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples @@ -311,8 +311,8 @@ class TAL_CSASRAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index c647392f0..28d0d3826 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -28,7 +28,7 @@ from lhotse.dataset import ( CutMix, DynamicBucketingSampler, K2SpeechRecognitionDataset, - SingleCutSampler, + 
SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -259,8 +259,8 @@ class TedLiumAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -282,7 +282,6 @@ class TedLiumAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] if self.args.concatenate_cuts: transforms = [ @@ -322,7 +321,6 @@ class TedLiumAsrDataModule: return valid_dl def test_dataloaders(self, cuts_test: CutSet) -> DataLoader: - logging.debug("About to create test dataset") if self.args.on_the_fly_feats: test = K2SpeechRecognitionDataset( diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index 51ca4cc6e..7c299d601 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -30,7 +30,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -225,8 +225,8 @@ class TimitAsrDataModule(DataModule): drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -267,7 +267,7 @@ class TimitAsrDataModule(DataModule): cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = SingleCutSampler( + valid_sampler = SimpleCutSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, @@ -298,7 +298,7 @@ class TimitAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration) + sampler = SimpleCutSampler(cuts_test, max_duration=self.args.max_duration) logging.debug("About to create test dataloader") test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) test_loaders.append(test_dl) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 746b212ff..c5967f10a 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -37,7 +37,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -296,8 +296,8 @@ class WenetSpeechAsrDataModule: drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py index 55d5f4636..6362ab7cd 100644 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -32,7 +32,7 @@ from lhotse.dataset 
import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples @@ -299,8 +299,8 @@ class Xbmu_AmdoAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index ada8c1a6c..dc66b217d 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -26,7 +26,7 @@ from lhotse.dataset import ( DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures from torch.utils.data import DataLoader @@ -196,8 +196,8 @@ class YesNoAsrDataModule(DataModule): drop_last=True, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/requirements-ci.txt b/requirements-ci.txt index 21d33001c..2433e190b 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -10,6 +10,7 @@ graphviz==0.19.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.13.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.13.1+cpu +six -f https://k2-fsa.org/nightly/ k2==1.23.4.dev20230319+cpu.torch1.13.1 diff --git a/test/test_ali.py b/test/test_ali.py index b107a6d80..d607e40aa 100755 --- a/test/test_ali.py +++ b/test/test_ali.py @@ -26,7 +26,7 @@ from pathlib import Path from lhotse import CutSet, load_manifest -from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler +from lhotse.dataset import K2SpeechRecognitionDataset, SimpleCutSampler from lhotse.dataset.collation import collate_custom_field from torch.utils.data import DataLoader @@ -44,7 +44,7 @@ def get_dataloader(): cuts = load_manifest(cuts_json) print(cuts[0]) cuts = cuts.with_features_path_prefix(egs_dir) - sampler = SingleCutSampler( + sampler = SimpleCutSampler( cuts, max_duration=10, shuffle=False, From fba17106228badbc77c5aa75c1a1263877067906 Mon Sep 17 00:00:00 2001 From: docterstrange <44291127+docterstrange@users.noreply.github.com> Date: Thu, 14 Sep 2023 09:58:28 +0800 Subject: [PATCH 034/113] modify tal_csasr recipe (#1252) Co-authored-by: zss11 --- egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index 3bfb832fb..3485d4005 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -724,12 +724,12 @@ def main(): ) save_results( params=params, - test_set_name=test_set, + test_set_name=test_set + "-zh", results_dict=zh_results_dict, ) save_results( params=params, - test_set_name=test_set, + test_set_name=test_set + "-en", results_dict=en_results_dict, ) From 565d2c2f5b920a4ea16be3c6ea04802c2350691a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 15 Sep 2023 02:37:53 +0800 Subject: [PATCH 035/113] Minor fixes to the libricss recipe 
(#1256) --- egs/libricss/SURT/prepare.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh index 028240e44..3d2581d96 100755 --- a/egs/libricss/SURT/prepare.sh +++ b/egs/libricss/SURT/prepare.sh @@ -79,7 +79,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # ln -sfv /path/to/rirs_noises $dl_dir/ # if [ ! -d $dl_dir/rirs_noises ]; then - lhotse download rirs_noises $dl_dir + lhotse download rir-noise $dl_dir/rirs_noises fi fi @@ -89,6 +89,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/librispeech. We perform text normalization for the transcripts. # NOTE: Alignments are required for this recipe. mkdir -p data/manifests + lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \ -j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/ fi @@ -112,7 +113,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # We assume that you have downloaded the RIRS_NOISES corpus # to $dl_dir/rirs_noises - lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests + lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises/RIRS_NOISES data/manifests fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then From 0c564c6c812bee08ebe7fa402f1668883b7847f3 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sun, 17 Sep 2023 13:25:37 +0900 Subject: [PATCH 036/113] Fix typo in README.md (#1257) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a876fb24e..523203aa4 100644 --- a/README.md +++ b/README.md @@ -338,7 +338,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss -The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English): +The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English): |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en | |--|--|--|--|--|--|--| |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| From 7e1288af50e699a1c09ad3c6acdd58b9765c5745 Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Tue, 19 Sep 2023 16:46:36 +0800 Subject: [PATCH 037/113] fix thchs-30 download command (#1260) --- egs/multi_zh-hans/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index 5d0fe66a4..c09b9c1de 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -49,7 +49,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Prepare THCHS-30" if [ ! -d $dl_dir/thchs30 ]; then log "Downloading THCHS-30" - lhotse download thchs30 $dl_dir/thchs30 + lhotse download thchs-30 $dl_dir/thchs30 fi if [ ! 
-f data/manifests/.thchs30.done ]; then From bbb03f7962eda9519e0947b92bb573140fdb2a04 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 20 Sep 2023 08:15:54 +0800 Subject: [PATCH 038/113] Update decoder.py (#1262) --- egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index b085a1817..bfd019ff5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -71,6 +71,10 @@ class Decoder(nn.Module): groups=decoder_dim // 4, # group size == 4 bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ From 45d60ef262fe65fa1a63cbdd7b89658b359f7724 Mon Sep 17 00:00:00 2001 From: l2009312042 Date: Thu, 21 Sep 2023 19:41:10 +0800 Subject: [PATCH 039/113] Update conformer.py (#1200) * Update conformer.py * Update zipformer.py fix bug in get_dynamic_dropout_rate --- .../ASR/pruned_transducer_stateless7_streaming/zipformer.py | 2 +- egs/librispeech/ASR/streaming_conformer_ctc/conformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index a5c422959..c7e45564f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module): return final_dropout_rate else: return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate + initial_dropout_rate - final_dropout_rate ) * (self.batch_count / warmup_period) def forward( diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index 5fe92172e..be6fabf35 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -230,7 +230,7 @@ class Conformer(Transformer): x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask ) # (T, B, F) else: - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F) if self.normalize_before: x = self.after_norm(x) From f5dc957d44350ea0ec9adb81578c32af5e6bb809 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 21 Sep 2023 21:16:14 +0800 Subject: [PATCH 040/113] Fix CI tests (#1266) --- .../ASR/pruned_transducer_stateless7/onnx_pretrained.py | 3 +++ .../ASR/pruned_transducer_stateless7/onnx_pretrained.py | 3 +++ .../onnx_pretrained.py | 3 +++ .../ASR/lstm_transducer_stateless2/onnx_pretrained.py | 3 +++ .../lstm_transducer_stateless2/streaming-onnx-decode.py | 5 +++++ .../ASR/pruned_transducer_stateless3/onnx_pretrained.py | 3 +++ .../ASR/pruned_transducer_stateless3/test_onnx.py | 5 +++++ .../onnx_pretrained-streaming.py | 3 +++ .../ASR/pruned_transducer_stateless7/test_onnx.py | 5 +++++ .../onnx_pretrained.py | 8 ++++++++ .../onnx_pretrained.py | 3 +++ .../ASR/zipformer/onnx_pretrained-streaming.py | 3 +++ egs/librispeech/ASR/zipformer/onnx_pretrained.py | 3 +++ 
.../ASR/pruned_transducer_stateless2/onnx_check.py | 5 +++++ .../onnx_pretrained-streaming.py | 3 +++ .../ASR/pruned_transducer_stateless5/onnx_pretrained.py | 3 +++ egs/yesno/ASR/tdnn/onnx_pretrained.py | 1 + 17 files changed, 62 insertions(+) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py index 5adb6c16a..a92182e8d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -151,12 +151,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -170,6 +172,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py index eee19191e..cf6ddfa36 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -152,12 +152,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -171,6 +173,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index 5d7e2dfcd..a6c69d54f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -136,6 +136,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -184,6 +185,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -197,6 +199,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py index fb9e121e5..06159e56a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py @@ 
-129,6 +129,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -166,6 +167,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -179,6 +181,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 34d2e5630..487fc2114 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -172,30 +172,35 @@ class Model: self.encoder = ort.InferenceSession( args.encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, args): self.decoder = ort.InferenceSession( args.decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_joiner(self, args): self.joiner = ort.InferenceSession( args.joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_joiner_encoder_proj(self, args): self.joiner_encoder_proj = ort.InferenceSession( args.joiner_encoder_proj_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_joiner_decoder_proj(self, args): self.joiner_decoder_proj = ort.InferenceSession( args.joiner_decoder_proj_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index e10915086..de3e03da6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -150,12 +150,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -169,6 +171,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 810da8da6..b98248128 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -78,6 +78,7 @@ def test_conv2d_subsampling(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -133,6 +134,7 @@ def test_rel_pos(): session = 
ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -220,6 +222,7 @@ def test_conformer_encoder_layer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -304,6 +307,7 @@ def test_conformer_encoder(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -359,6 +363,7 @@ def test_conformer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index 29be4c655..6e290e799 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -138,6 +138,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -185,6 +186,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -198,6 +200,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py index 1e9b67226..f3f7b1ea9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -74,6 +74,7 @@ def test_conv2d_subsampling(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -128,6 +129,7 @@ def test_rel_pos(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -204,6 +206,7 @@ def test_zipformer_encoder_layer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -284,6 +287,7 @@ def test_zipformer_encoder(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -338,6 +342,7 @@ def test_zipformer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py index 8ff02fbcb..494a34d97 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -326,41 +326,49 @@ def main(): encoder = ort.InferenceSession( args.encoder_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) decoder = ort.InferenceSession( args.decoder_model_filename, 
sess_options=session_opts, + providers=["CPUExecutionProvider"], ) joiner = ort.InferenceSession( args.joiner_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) joiner_encoder_proj = ort.InferenceSession( args.joiner_encoder_proj_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) joiner_decoder_proj = ort.InferenceSession( args.joiner_decoder_proj_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) lconv = ort.InferenceSession( args.lconv_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) frame_reducer = ort.InferenceSession( args.frame_reducer_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) ctc_output = ort.InferenceSession( args.ctc_output_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) sp = spm.SentencePieceProcessor() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 8192e01fd..04861ea37 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -130,6 +130,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -229,6 +230,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -242,6 +244,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 500b2cd09..e62491444 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -146,6 +146,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -236,6 +237,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -249,6 +251,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index 032b07721..334376093 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -151,12 +151,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = 
self.decoder.get_modelmeta().custom_metadata_map @@ -170,6 +172,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py index a46ff5a07..2d46eede1 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -258,6 +258,7 @@ def main(): encoder_session = ort.InferenceSession( args.onnx_encoder_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) test_encoder(model, encoder_session) @@ -265,6 +266,7 @@ def main(): decoder_session = ort.InferenceSession( args.onnx_decoder_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) test_decoder(model, decoder_session) @@ -272,14 +274,17 @@ def main(): joiner_session = ort.InferenceSession( args.onnx_joiner_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) joiner_encoder_proj_session = ort.InferenceSession( args.onnx_joiner_encoder_proj_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) joiner_decoder_proj_session = ort.InferenceSession( args.onnx_joiner_decoder_proj_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) test_joiner( model, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index facfc2258..c31db6859 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -139,6 +139,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -186,6 +187,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -199,6 +201,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py index e7c8b4556..c784853ee 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -158,12 +158,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -177,6 +179,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py 
b/egs/yesno/ASR/tdnn/onnx_pretrained.py index b23a2a381..72a1d69c8 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -54,6 +54,7 @@ class OnnxModel: self.model = ort.InferenceSession( nn_model, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) meta = self.model.get_modelmeta().custom_metadata_map From 34e40a86b33102576b3442329421178a487e3ea3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 22 Sep 2023 09:57:15 +0800 Subject: [PATCH 041/113] Fix exporting decoder model to onnx (#1264) * Use torch.jit.script() to export the decoder model See also https://github.com/k2-fsa/sherpa-onnx/issues/327 --- egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py | 1 + egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py | 1 + .../ASR/conv_emformer_transducer_stateless2/export-onnx.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py | 1 + .../ASR/pruned_transducer_stateless5/export-onnx-streaming.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/export-onnx.py | 1 + egs/librispeech/ASR/zipformer/export-onnx-streaming.py | 1 + egs/librispeech/ASR/zipformer/export-onnx.py | 1 + egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py | 1 + .../ASR/pruned_transducer_stateless5/export-onnx-streaming.py | 1 + egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py | 1 + 17 files changed, 17 insertions(+) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index e8211500a..2a9fc57d5 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -322,6 +322,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py index 0c98885ac..2b9f2293a 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py @@ -330,6 +330,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index cfd365207..ab046557f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -401,6 +401,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index 89ced388c..2a52e2eec 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -359,6 +359,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index 6b6cb893f..c543628ff 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -356,6 +356,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py index 282238c13..0a2132e56 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -307,6 +307,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 26dea7e11..2685ea95a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -312,6 +312,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 549fb13c9..b90d81dcf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -404,6 +404,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index fff0fcdd5..02aa24f2c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -335,6 +335,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index 11c885f4d..b75548f8b 100755 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -329,6 +329,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 8653126de..2de56837e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -413,6 +413,7 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index 6f84d79b4..d71080760 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -401,6 +401,7 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index a951aeef3..e2c7d7d95 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -506,6 +506,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index e0d664009..3682f0b62 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -353,6 +353,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index 760fad974..140b1d37f 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -315,6 +315,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 9a926d7e5..921766ad4 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ 
b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -404,6 +404,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py index 68c7cc352..037c7adf1 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -335,6 +335,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, From ef658d691e75041398abb76567c810af1c22c7fc Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 24 Sep 2023 17:06:47 +0800 Subject: [PATCH 042/113] fixes for init value of `diagnostics.TensorDiagnosticOptions` (#1269) * fixes for `diagnostics` Replace `2 ** 22` with `512` as the default value of `diagnostics.TensorDiagnosticOptions` also black formatted some scripts * fixed formatting issues --- .../ASR/pruned_transducer_stateless2/train.py | 3 +- .../ASR/pruned_transducer_stateless2/train.py | 2 +- .../ASR/pruned_transducer_stateless3/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../pruned_transducer_stateless7/train2.py | 2 +- .../train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 3 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../ASR/pruned_transducer_stateless2/train.py | 3 +- .../pruned_transducer_stateless7/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../train.py | 2 +- .../train2.py | 2 +- .../train.py | 2 +- .../train.py | 2 +- .../train2.py | 2 +- .../ASR/pruned2_knowledge/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../pruned_transducer_stateless7/finetune.py | 2 +- .../ASR/pruned_transducer_stateless7/optim.py | 12 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- .../pruned_transducer_stateless7_ctc/train.py | 2 +- .../train.py | 2 +- .../train.py | 2 +- .../train2.py | 2 +- .../train.py | 2 +- .../ASR/pruned_transducer_stateless8/train.py | 2 +- .../ASR/streaming_conformer_ctc/conformer.py | 4 +- egs/librispeech/ASR/zipformer/decoder.py | 30 +- egs/librispeech/ASR/zipformer/joiner.py | 9 +- egs/librispeech/ASR/zipformer/onnx_decode.py | 4 +- egs/librispeech/ASR/zipformer/optim.py | 22 +- egs/librispeech/ASR/zipformer/profile.py | 12 +- egs/librispeech/ASR/zipformer/scaling.py | 715 ++++++++++-------- .../ASR/zipformer/streaming_decode.py | 57 +- egs/librispeech/ASR/zipformer/subsampling.py | 16 +- egs/librispeech/ASR/zipformer/train.py | 21 +- egs/librispeech/ASR/zipformer_mmi/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 5 +- egs/multi_zh-hans/ASR/zipformer/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../train.py | 4 +- egs/tedlium3/ASR/conformer_ctc2/train.py | 2 +- egs/tedlium3/ASR/zipformer/train.py | 2 +- .../pruned_transducer_stateless2/finetune.py | 2 +- .../ASR/pruned_transducer_stateless2/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- egs/wenetspeech/ASR/zipformer/train.py | 2 +- .../ASR/pruned_transducer_stateless5/train.py | 2 +- .../ASR/pruned_transducer_stateless7/train.py | 2 +- 51 files changed, 
511 insertions(+), 479 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index c9d9c4aa8..fa809b768 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index d08908238..60f014c48 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -872,7 +872,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 62e67530d..7c23041ca 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -1045,7 +1045,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index cbb7db086..11671db92 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py index c30f6f960..057af297f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py @@ -1031,7 +1031,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 4e52f9573..3858bafd7 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 74bf68ccb..8c7448d4c 100755 --- 
a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -730,7 +730,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -919,7 +918,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 47015cbe7..a354f761e 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -908,7 +908,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index e57b5c859..30154291d 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 45d777922..8f09f1aa5 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -999,7 +999,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index 8c8d9593b..9b67141c0 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -988,7 +988,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 4bd5b83a2..4aedeffe4 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index 18cb75c37..73fcd67aa 100755 --- 
a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1074,7 +1074,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py index bc4bcf253..4c866ddd8 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1075,7 +1075,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c5a05d349..ca21bd6bf 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 6bb37b017..23ddb6bec 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py index 36067510c..420dc1065 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -955,7 +955,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 77e06d3b7..a4899f7bd 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -811,7 +811,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 3b5a635e4..66dc5f991 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1003,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = 
diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 3ee2b7d65..4e261dbc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -1132,7 +1132,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index aa3cef338..8ab3589da 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -117,7 +117,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): parameters_names=None, show_dominant_parameters=True, ): - assert parameters_names is not None, ( "Please prepare parameters_names," "which is a List[List[str]]. Each List[str] is for a group" @@ -224,9 +223,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -325,7 +322,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -410,7 +407,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. 
batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -426,7 +423,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 2b4d51089..fac3706d2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b387968a9..d8fa08372 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1052,7 +1052,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 23fb6f497..25a1aa674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1042,7 +1042,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 99090b2c1..2d915ff87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1029,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py index 9be629149..aa6c0668a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1030,7 +1030,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index b494253d6..565dc7a16 100755 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -1141,7 +1141,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index bee414292..3f271c5b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,7 +1154,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index be6fabf35..0b982f4bf 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -230,7 +230,9 @@ class Conformer(Transformer): x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask ) # (T, B, F) else: - x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) if self.normalize_before: x = self.after_norm(x) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index e8db988f6..e77e54118 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -61,10 +61,15 @@ class Decoder(nn.Module): ) # the balancers are to avoid any drift in the magnitude of the # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) self.blank_id = blank_id @@ -81,10 +86,15 @@ class Decoder(nn.Module): groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ @@ -107,9 +117,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index f03cc930e..dfb0a0057 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -52,12 +52,13 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). 
""" - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) + assert encoder_out.ndim == decoder_out.ndim, ( + encoder_out.shape, + decoder_out.shape, + ) if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py index 2aca36ca9..356c2a830 100755 --- a/egs/librispeech/ASR/zipformer/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer/onnx_decode.py @@ -303,7 +303,9 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() - results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) end_time = time.time() elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index abfb2092c..c9b76526c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -299,8 +298,8 @@ class ScaledAdam(BatchedOptimizer): # the input is groups of parameter or named parameter. for cur_group in iterable_or_groups: assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] del cur_group["named_params"] cur_group["params"] = p_list param_groups.append(cur_group) @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -428,7 +425,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -513,7 +510,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. 
batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -529,7 +526,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -667,8 +663,7 @@ class ScaledAdam(BatchedOptimizer): # We have to look at the trained model for parameters at or around the # param_max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) delta = state["delta"] # the factor of (1-beta1) relates to momentum. @@ -879,7 +874,8 @@ class Eden(LRScheduler): warmup_factor = ( 1.0 if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) # else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) @@ -1111,7 +1107,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/profile.py index b460b5338..57f44a90a 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/profile.py @@ -100,17 +100,13 @@ class Model(nn.Module): self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: x, x_lens = self.encoder_embed(feature, feature_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) logits = self.encoder_proj(encoder_out) @@ -168,9 +164,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7c98ef045..23fd279b3 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -25,6 +25,7 @@ import math import torch.nn as nn from torch import Tensor + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) diff = torch.abs(x - y) @@ -55,28 +56,34 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: # for torch.jit.trace() return torch.logaddexp(x, y) + class PiecewiseLinear(object): """ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] respectively. 
""" + def __init__(self, *args): assert len(args) >= 1, len(args) if len(args) == 1 and isinstance(args[0], PiecewiseLinear): self.pairs = list(args[0].pairs) else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: + self.pairs = [(float(x), float(y)) for x, y in args] + for (x, y) in self.pairs: assert isinstance(x, (float, int)), type(x) assert isinstance(y, (float, int)), type(y) for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) def __str__(self): # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" def __call__(self, x): if x <= self.pairs[0][0]: @@ -93,37 +100,36 @@ class PiecewiseLinear(object): assert False def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) def __add__(self, x): if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) s, x = self.get_common_basis(x) return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) def max(self, x): if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def min(self, x): if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def __eq__(self, other): return self.pairs == other.pairs - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): """ Returns (self_mod, p_mod) which are equivalent piecewise linear functions to self and p, but with the same x values. @@ -135,28 +141,30 @@ class PiecewiseLinear(object): assert isinstance(p, PiecewiseLinear), type(p) # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] if include_crossings: extra_x_vals = [] for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): # if the two lines in this subsegment potentially cross each other.. diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) # `pos`, between 0 and 1, gives the relative x position, # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) class ScheduledFloat(torch.nn.Module): @@ -176,9 +184,8 @@ class ScheduledFloat(torch.nn.Module): `default` is used when self.batch_count is not set or not in training mode or in torch.jit scripting mode. """ - def __init__(self, - *args, - default: float = 0.0): + + def __init__(self, *args, default: float = 0.0): super().__init__() # self.batch_count and self.name will be written to in the training loop. self.batch_count = None @@ -187,47 +194,55 @@ class ScheduledFloat(torch.nn.Module): self.schedule = PiecewiseLinear(*args) def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) def __float__(self): batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): return float(self.default) else: ans = self.schedule(self.batch_count) if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) return ans def __add__(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) + return ScheduledFloat(self.schedule + x, default=self.default) else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) def max(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) + return ScheduledFloat(self.schedule.max(x), default=self.default) else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) FloatLike = Union[float, ScheduledFloat] -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -242,6 +257,7 @@ class CutoffEstimator: p is the proportion of items that should be above the cutoff. 
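# Illustrative sketch (not part of the patch): how the schedule helpers above are
# typically used. PiecewiseLinear interpolates between (x, y) pairs and clamps
# outside the end points; ScheduledFloat wraps such a schedule and is read with
# float() using the batch_count set by the training loop. Values are placeholders.
from scaling import PiecewiseLinear, ScheduledFloat  # assumption: this scaling.py is importable

p = PiecewiseLinear((0, 10.0), (100, 0.0))
assert p(-5) == 10.0 and p(50) == 5.0 and p(200) == 0.0

s = ScheduledFloat((0.0, 0.3), (20000.0, 0.1), default=0.2)
s.batch_count = 10000.0   # normally written by the training loop
s.train()
print(float(s))           # 0.2, halfway through the schedule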
""" + def __init__(self, p: float): self.p = p # total count of items @@ -255,7 +271,7 @@ class CutoffEstimator: """ Returns true if x is above the cutoff. """ - ans = (x > self.cutoff) + ans = x > self.cutoff self.count += 1 if ans: self.count_above += 1 @@ -263,7 +279,7 @@ class CutoffEstimator: delta_p = cur_p - self.p if (delta_p > 0) == ans: q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) + self.cutoff = x * q + self.cutoff * (1 - q) return ans @@ -272,6 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -287,7 +304,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -306,17 +323,16 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x @staticmethod @@ -328,15 +344,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -348,8 +369,14 @@ class BiasNormFunction(torch.autograd.Function): # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). 
@staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: assert bias.ndim == 1 if channel_dim < 0: channel_dim = channel_dim + x.ndim @@ -357,10 +384,16 @@ class BiasNormFunction(torch.autograd.Function): ctx.channel_dim = channel_dim for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) return ans @staticmethod @@ -376,7 +409,9 @@ class BiasNormFunction(torch.autograd.Function): log_scale.requires_grad = True with torch.enable_grad(): # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), log_scale.grad, None, None @@ -412,14 +447,15 @@ class BiasNorm(torch.nn.Module): than the input of this module to be required to be stored for the backprop. """ + def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels @@ -442,23 +478,24 @@ class BiasNorm(torch.nn.Module): bias = self.bias for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() return x * scales - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. 
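# Illustrative sketch (not part of the patch): the normalization that BiasNorm
# applies, written out in plain torch for a (batch, time, channels) input.
# This only restates the scales/ans computation from the forward code above;
# the tensor sizes and the bias/log_scale values are placeholders for the
# learned parameters.
import torch

x = torch.randn(2, 5, 8)          # (batch, time, num_channels)
bias = torch.zeros(8)             # stands in for the learned per-channel bias
log_scale = torch.tensor(1.0)     # stands in for the learned scalar log-scale
scales = torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5 * log_scale.exp()
y = x * scales                    # what BiasNorm computes, up to the learned values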
@@ -477,15 +514,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -504,15 +537,11 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: """ Behaves like a constructor of a modified version of nn.Conv2d that gives an easy way to set the default initial parameter scale. @@ -532,9 +561,7 @@ def ScaledConv2d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -562,29 +589,36 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): Another option, if you want to do something like this, is to re-initialize the parameters. """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): super().__init__() assert kernel_size % 2 == 1 half_kernel_size = (kernel_size + 1) // 2 # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) # first row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, @@ -596,17 +630,15 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): self.causal_conv.weight[:] *= initial_scale self.chunkwise_conv.weight[:] *= initial_scale if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + Forward function. 
Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ (batch_size, num_channels, seq_len) = x.shape @@ -622,28 +654,32 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): x = torch.nn.functional.pad(x, (left_pad, right_pad)) - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) assert x_causal.shape == (batch_size, num_channels, seq_len) x_chunk = x[..., left_pad:] num_chunks = x_chunk.shape[2] // chunk_size x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) x_chunk = self.chunkwise_conv(x_chunk) # does not change shape chunk_scale = self._get_chunk_scale(chunk_size) x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] return x_chunk + x_causal def _get_chunk_scale(self, chunk_size: int): """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" + scale the output of self.chunkwise_conv.""" left_edge = self.chunkwise_conv_scale[0] right_edge = self.chunkwise_conv_scale[1] if chunk_size < self.kernel_size: @@ -652,9 +688,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): else: t = chunk_size - self.kernel_size channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) left_edge = torch.cat((left_edge, pad), dim=-1) right_edge = torch.cat((pad, right_edge), dim=-1) return 1.0 + (left_edge + right_edge) @@ -698,14 +734,14 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): class BalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim @@ -715,10 +751,8 @@ class BalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config try: @@ -727,8 +761,8 @@ class BalancerFunction(torch.autograd.Function): x = x.to(torch.float32) x = x.detach() x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) mean = x.mean(dim=mean_dims, keepdim=True) stddev = (uncentered_var - (mean * 
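# Illustrative sketch (not part of the patch): calling the chunk-causal depthwise
# convolution defined above. The causal branch carries information across chunk
# boundaries while the chunkwise branch works within each chunk; the output keeps
# the input shape. Sizes below are placeholders.
import torch
from scaling import ChunkCausalDepthwiseConv1d  # assumption: this scaling.py is importable

conv = ChunkCausalDepthwiseConv1d(channels=16, kernel_size=7)
x = torch.randn(3, 16, 40)      # (batch_size, channels, seq_len)
y = conv(x, chunk_size=20)      # two chunks of 20 frames each
assert y.shape == x.shape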
mean)).clamp(min=1.0e-20).sqrt() rms = uncentered_var.clamp(min=1.0e-20).sqrt() @@ -742,11 +776,16 @@ class BalancerFunction(torch.autograd.Function): rms_clamped = rms.clamp(min=min_rms, max=max_rms) r_loss = (rms_clamped / rms).log().abs() - loss = (m_loss + r_loss) + loss = m_loss + r_loss loss.backward(gradient=torch.ones_like(loss)) loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) loss_grad = loss_grad * (grad_scale / loss_grad_rms) @@ -757,7 +796,9 @@ class BalancerFunction(torch.autograd.Function): x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) x_grad = x_grad_mod.to(x_grad.dtype) except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -793,16 +834,17 @@ class Balancer(torch.nn.Module): on each forward(). This is done randomly to prevent all layers from doing it at the same time. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, ): super().__init__() @@ -823,8 +865,11 @@ class Balancer(torch.nn.Module): self.grad_scale = grad_scale def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): return _no_op(x) prob = float(self.prob) @@ -842,7 +887,7 @@ class Balancer(torch.nn.Module): eps = 1.0e-10 # eps is to prevent crashes if x is exactly 0 or 1. # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 def _approx_inverse_erf(x): # 1 / (sqrt(pi) * ln(2)), @@ -853,6 +898,7 @@ class Balancer(torch.nn.Module): # and math.erf(0.0407316414078772) = 0.045935330944660666, # which is pretty close to 0.05. return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error # function returns x = -1 + (2 * x) @@ -873,8 +919,9 @@ class Balancer(torch.nn.Module): return _no_op(x) -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of the absolute values of elements of x over the limit "limit". E.g. if @@ -910,13 +957,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. 
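# Illustrative sketch (not part of the patch): Balancer returns its input unchanged
# in the forward pass and, with some probability per call, mixes penalty terms into
# the gradient so that each channel's positive fraction and RMS move toward the
# configured ranges. The constructor values mirror the defaults shown above.
import torch
from scaling import Balancer  # assumption: this scaling.py is importable

balancer = Balancer(
    num_channels=8,
    channel_dim=-1,
    min_positive=0.05,
    max_positive=0.95,
    min_abs=0.2,
    max_abs=100.0,
)
x = torch.randn(4, 10, 8, requires_grad=True)
y = balancer(x)               # numerically identical to x in the forward pass
(y ** 2).sum().backward()     # constraint penalties may be folded into x.grad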
else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -946,25 +992,22 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: ctx.save_for_backward(x) ctx.module = module return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors w = ctx.module try: @@ -976,8 +1019,10 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, w.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) if metric < float(w.whitening_limit): w.prob = w.min_prob @@ -986,22 +1031,27 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." 
+ ) return x_grad, None class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -1033,10 +1083,9 @@ class Whiten(nn.Module): (self.min_prob, self.max_prob) = prob assert 0 < self.min_prob <= self.max_prob <= 1 self.prob = self.max_prob - self.name = None # will be set in training loop + self.name = None # will be set in training loop - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -1071,9 +1120,11 @@ class WithLoss(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) def with_loss(x, y, name): @@ -1118,20 +1169,21 @@ class LimitParamValue(torch.autograd.Function): @staticmethod def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) return x_grad, None, None -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. @@ -1187,7 +1239,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 @@ -1197,7 +1249,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.044 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1210,12 +1264,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. 
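# Illustrative sketch (not part of the patch): Whiten is likewise a pure
# gradient-space regularizer. The forward pass returns x unmodified; in backward,
# a scaled penalty gradient is added only when the whitening metric of the
# features exceeds whitening_limit (this mirrors _test_whiten further below).
import torch
from scaling import Whiten  # assumption: this scaling.py is importable

w = Whiten(num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.1)
x = torch.randn(20, 64, requires_grad=True)
y = w(x)              # y is x, unmodified
y.sum().backward()    # any whitening penalty is added to x.grad here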
floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) + d = d * ((ceil - floor) / 255.0) + floor return y_grad * d @@ -1239,9 +1293,7 @@ class Dropout2(nn.Module): self.p = p def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) class MulForDropout3(torch.autograd.Function): @@ -1259,7 +1311,7 @@ class MulForDropout3(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, ans_grad): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) return x_grad, None, None @@ -1286,7 +1338,7 @@ class Dropout3(nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1308,13 +1360,15 @@ class SwooshLFunction(torch.autograd.Function): if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = coeff ceil = 1.0 + coeff + 0.005 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1328,20 +1382,19 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. coeff = -0.08 floor = coeff ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 @@ -1351,19 +1404,19 @@ class SwooshL(torch.nn.Module): return k2.swoosh_l(x) # return SwooshLFunction.apply(x) + class SwooshLOnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 class SwooshRFunction(torch.autograd.Function): """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - derivatives are between -0.08 and 0.92. + derivatives are between -0.08 and 0.92. """ @staticmethod @@ -1379,17 +1432,19 @@ class SwooshRFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = -0.08 ceil = 0.925 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. 
assert d_scaled.min() >= 0.0 @@ -1403,33 +1458,32 @@ class SwooshRFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.08 ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshR(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: return k2.swoosh_r_forward(x) else: return k2.swoosh_r(x) # return SwooshRFunction.apply(x) + class SwooshROnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 # simple version of SwooshL that does not redefine the backprop, used in @@ -1437,7 +1491,7 @@ class SwooshROnnx(torch.nn.Module): def SwooshLForward(x: Tensor): x_offset = x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.035 @@ -1446,28 +1500,30 @@ def SwooshLForward(x: Tensor): def SwooshRForward(x: Tensor): x_offset = x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.313261687 class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): if dropout_p != 0.0: dropout_shape = list(x.shape) if dropout_shared_dim is not None: dropout_shape[dropout_shared_dim] = 1 # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) else: dropout_mask = None @@ -1476,8 +1532,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): ctx.activation = activation forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. 
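# Illustrative sketch (not part of the patch): the two activations defined above,
# written directly from their formulas and checked against the simple forward
# helpers SwooshLForward / SwooshRForward quoted in this hunk.
import torch
from scaling import SwooshLForward, SwooshRForward  # assumption: this scaling.py is importable

x = torch.linspace(-6.0, 6.0, steps=13)
swoosh_l = torch.logaddexp(torch.zeros_like(x), x - 4.0) - 0.08 * x - 0.035
swoosh_r = torch.logaddexp(torch.zeros_like(x), x - 1.0) - 0.08 * x - 0.313261687
assert torch.allclose(swoosh_l, SwooshLForward(x), atol=1e-5)
assert torch.allclose(swoosh_r, SwooshRForward(x), atol=1e-5)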
@@ -1495,8 +1551,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1511,8 +1567,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): in_channels = y.shape[-1] g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) y_deriv = torch.matmul(ans_grad, weight) bias_deriv = None if bias is None else g.sum(dim=0) x_deriv = y_deriv * func_deriv @@ -1525,71 +1580,76 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): class ActivationDropoutAndLinear(torch.nn.Module): """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). 
""" - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): super().__init__() # create a temporary module of nn.Linear that we'll steal the # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) self.weight = l.weight # register_parameter properly handles making it a parameter when l.bias # is None. I think there is some reason for doing it this way rather # than just setting it to None but I don't know what it is, maybe # something to do with exporting the module.. - self.register_parameter('bias', l.bias) + self.register_parameter("bias", l.bias) self.activation = activation self.dropout_p = dropout_p self.dropout_shared_dim = dropout_shared_dim - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': + if self.activation == "SwooshL": x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) else: assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) + return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: @@ -1612,10 +1672,9 @@ def _test_whiten(): x.requires_grad = True - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1656,9 +1715,7 @@ def _test_balancer_sign(): def _test_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = Balancer( @@ -1685,7 +1742,7 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) # for self-test. @@ -1699,7 +1756,7 @@ def _test_swooshl_deriv(): x.requires_grad = True m = SwooshL() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. @@ -1713,7 +1770,7 @@ def _test_swooshr_deriv(): x.requires_grad = True m = SwooshR() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. 
@@ -1727,24 +1784,24 @@ def _test_softmax(): b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) + p = PiecewiseLinear((0, 10.0)) for x in [-100, 0, 100]: assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: print("x, y = ", x, y) assert p(x) == y, (x, p(x), y) q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] pq = p.max(q) for x in x_vals: y1 = max(p(x), q(x)) @@ -1757,7 +1814,7 @@ def _test_piecewise_linear(): assert abs(y1 - y2) < 0.001 pq = p + q for x in x_vals: - y1 = p(x) + q(x) + y1 = p(x) + q(x) y2 = pq(x) assert abs(y1 - y2) < 0.001 @@ -1772,15 +1829,22 @@ def _test_activation_dropout_and_linear(): # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() # internally, messing up the random state. for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) with torch.no_grad(): m2.weight[:] = m1[2].weight if bias: @@ -1790,9 +1854,9 @@ def _test_activation_dropout_and_linear(): x1.requires_grad = True # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) x2 = x1.clone().detach() x2.requires_grad = True @@ -1805,21 +1869,24 @@ def _test_activation_dropout_and_linear(): y2 = m2(x2) y2.backward(gradient=y_grad) - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) def isclose(a, b): # return true if cosine similarity is > 0.9. 
- return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte # storage of it. assert isclose(x1.grad, x2.grad) diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 44ff392a3..904caf8af 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states @@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -380,11 +374,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -404,9 +394,7 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -494,9 +482,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens @@ -517,9 +503,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) @@ -577,9 +561,7 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
- initial_states = get_init_states( - model=model, batch_size=1, device=device - ) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -649,9 +631,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -684,8 +664,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -718,9 +697,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." @@ -760,9 +737,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -789,9 +766,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 6532ddccb..d16d87bac 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -107,9 +107,7 @@ class ConvNeXt(nn.Module): if layerdrop_rate != 0.0: batch_size = x.shape[0] mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) + torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate ) else: @@ -278,9 +276,7 @@ class Conv2dSubsampling(nn.Module): # many copies of this extra gradient term. 
self.out_whiten = Whiten( num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), prob=(0.025, 0.25), grad_scale=0.02, ) @@ -331,7 +327,7 @@ class Conv2dSubsampling(nn.Module): with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) return x, x_lens @@ -403,8 +399,8 @@ class Conv2dSubsampling(nn.Module): left_pad = self.convnext.padding[0] freq = self.out_width channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) + cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( + device + ) return cached_embed_left_pad diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index bc3e9c1ba..7009f3346 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: - assert ( - params.use_transducer or params.use_ctc - ), (f"At least one of them should be True, " + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}") + f"params.use_ctc={params.use_ctc}" + ) encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) @@ -808,17 +808,16 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. 
simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss += ( - simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1166,7 +1165,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1b3ea3e0..4b50acdde 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -981,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py index a68702776..48468cfbd 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -746,7 +746,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -966,7 +965,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1019,7 +1018,6 @@ def run(rank, world_size, args): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -1118,7 +1116,6 @@ def scan_pessimistic_batches_for_oom( # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 4f2d728be..c1bbd2ee8 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -1164,7 +1164,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index 417515968..d03970265 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -915,7 +915,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index d80e0147c..aee3972cd 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -69,7 +69,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer -from icefall import diagnostics, byte_encode, tokenize_by_CJK_char +from icefall import byte_encode, diagnostics, tokenize_by_CJK_char from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -1018,7 +1018,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py index 42e4c010a..fc3e3b2d9 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -905,7 +905,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 9271c8438..33d03908c 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -1126,7 +1126,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index e703100a9..82bc882bd 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -886,7 +886,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git 
a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 48b347b64..49977e01b 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -851,7 +851,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 8e1b12dba..931e699d9 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -985,7 +985,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 83dbfa22f..b1557dedb 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py index 5b5ac17be..a6fa46b17 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index f8dd7b287..8c53972fd 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -993,7 +993,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) From ef5da4824d033153f118556bd8407ace848061d2 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 24 Sep 2023 17:31:01 +0800 Subject: [PATCH 043/113] formatted the entire LibriSpeech recipe (#1270) * formatted the entire librispeech recipe * minor updates --- egs/librispeech/ASR/conformer_ctc/train.py | 1 - egs/librispeech/ASR/local/download_lm.py | 1 + .../ASR/long_file_recog/beam_search.py | 4 +- .../ASR/long_file_recog/merge_chunks.py | 1 - .../ASR/long_file_recog/recognize.py | 1 + .../ASR/pruned2_knowledge/optim.py | 1 - .../beam_search.py | 12 +- .../ASR/pruned_transducer_stateless2/optim.py | 1 - .../pruned_transducer_stateless2/scaling.py | 1 - .../pruned_transducer_stateless6/vq_utils.py | 1 - .../pruned_transducer_stateless7/alignment.py | 2 +- .../ASR/streaming_conformer_ctc/train.py | 1 - egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 1 - egs/librispeech/ASR/transducer/train.py | 1 - egs/librispeech/ASR/transducer_lstm/train.py | 1 - egs/librispeech/ASR/zipformer/scaling.py | 2 +- 
icefall/__init__.py | 8 +- icefall/context_graph.py | 1 - icefall/diagnostics.py | 206 ++++++++++-------- icefall/profiler.py | 61 ++---- icefall/rnn_lm/check-onnx-streaming.py | 1 - icefall/rnn_lm/train.py | 1 - icefall/shared/make_kn_lm.py | 2 - icefall/transformer_lm/model.py | 1 - requirements-ci.txt | 1 + requirements.txt | 1 + 26 files changed, 144 insertions(+), 171 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 99fe64793..828106f41 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -557,7 +557,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index da1648d06..5a36ff2a9 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -43,6 +43,7 @@ from pathlib import Path from tqdm.auto import tqdm + # This function is copied from lhotse def tqdm_urlretrieve_hook(t): """Wraps tqdm instance. diff --git a/egs/librispeech/ASR/long_file_recog/beam_search.py b/egs/librispeech/ASR/long_file_recog/beam_search.py index f8c31861c..b65e9d40a 100644 --- a/egs/librispeech/ASR/long_file_recog/beam_search.py +++ b/egs/librispeech/ASR/long_file_recog/beam_search.py @@ -236,7 +236,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -507,7 +507,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] diff --git a/egs/librispeech/ASR/long_file_recog/merge_chunks.py b/egs/librispeech/ASR/long_file_recog/merge_chunks.py index d38d9c86a..9e31e00d5 100755 --- a/egs/librispeech/ASR/long_file_recog/merge_chunks.py +++ b/egs/librispeech/ASR/long_file_recog/merge_chunks.py @@ -162,7 +162,6 @@ def merge_chunks( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: - for cut in cuts_chunk: cur_rec_id = cut.recording.id if len(cut_list) == 0: diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py index 96c83f859..466253446 100755 --- a/egs/librispeech/ASR/long_file_recog/recognize.py +++ b/egs/librispeech/ASR/long_file_recog/recognize.py @@ -264,6 +264,7 @@ def decode_dataset( - timestamps of reference transcript - timestamps of predicted result """ + # Background worker to add alignemnt and save cuts to disk. 
def _save_worker( cuts: List[Cut], diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..9f287ce70 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -66,7 +66,6 @@ class Eve(Optimizer): weight_decay=1e-3, target_rms=0.1, ): - if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 3298568a3..7fcd242fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -719,7 +719,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1019,7 +1019,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -2608,7 +2608,6 @@ def modified_beam_search_LODR( context_score = 0 new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): - if context_graph is not None: ( context_score, @@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] # get batch @@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion( new_token = topk_token_indexes[k] new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): - ys.append(new_token) new_timestamp.append(t) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..f54bc2709 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -66,7 +66,6 @@ class Eve(Optimizer): weight_decay=1e-3, target_rms=0.1, ): - if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 963ebdc2d..91d64c1df 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM): return with torch.cuda.device_of(first_fw): - # Note: 
no_grad() is necessary since _cudnn_rnn_flatten_weight is # an inplace operation on self._flat_weights with torch.no_grad(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 14ff86f23..3bca7db2c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -56,7 +56,6 @@ class CodebookIndexExtractor: """ def __init__(self, params: AttributeDict): - self.params = params params.subsets = ["clean-100"] if self.params.full_libri: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py index 76cd56bbb..bfb5fe609 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -111,7 +111,7 @@ def batch_force_alignment( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index bb55ed6bb..14d7274c2 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -543,7 +543,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 0aa1587ba..90245ed46 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -463,7 +463,6 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}" ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index f2a09346c..9ac6b7d03 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -513,7 +513,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index a6f2bd08c..92134116c 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -517,7 +517,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 23fd279b3..c0f1e3087 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -70,7 +70,7 @@ class PiecewiseLinear(object): self.pairs = list(args[0].pairs) else: self.pairs = [(float(x), float(y)) for x, y in args] - for (x, y) in self.pairs: + for x, y in self.pairs: assert isinstance(x, (float, int)), type(x) assert isinstance(y, (float, int)), type(y) diff --git a/icefall/__init__.py b/icefall/__init__.py index 05e2b408c..b1e4313e9 
100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,12 +1,6 @@ # isort:skip_file -from . import ( - checkpoint, - decode, - dist, - env, - utils -) +from . import checkpoint, decode, dist, env, utils from .byte_utils import ( byte_decode, diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 01836df04..0b7c42c0b 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -227,7 +227,6 @@ class ContextGraph: filename: Optional[str] = "", symbol_table: Optional[Dict[int, str]] = None, ) -> "Digraph": # noqa - """Visualize a ContextGraph via graphviz. Render ContextGraph as an image via graphviz, and return the Digraph object; diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 98870684e..700dc1500 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -23,6 +23,7 @@ from typing import Optional, Tuple, List import torch from torch import Tensor, nn + class TensorDiagnosticOptions(object): """Options object for tensor diagnostics: @@ -77,11 +78,11 @@ def get_tensor_stats( elif stats_type == "abs": x = x.abs() elif stats_type == "rms": - x = x ** 2 + x = x**2 elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: - assert stats_type in [ "value", "max", "min" ] + assert stats_type in ["value", "max", "min"] sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: @@ -121,10 +122,10 @@ class TensorDiagnostic(object): self.class_name = None # will assign in accumulate() self.stats = None # we'll later assign a list to self.stats. - # It's a list of dicts, indexed by dim (i.e. by the - # axis of the tensor). The dicts, in turn, are - # indexed by `stats-type` which are strings in - # ["abs", "max", "min", "positive", "value", "rms"]. + # It's a list of dicts, indexed by dim (i.e. by the + # axis of the tensor). The dicts, in turn, are + # indexed by `stats-type` which are strings in + # ["abs", "max", "min", "positive", "value", "rms"]. # scalar_stats contains some analysis of the activations and gradients, self.scalar_stats = None @@ -139,7 +140,6 @@ class TensorDiagnostic(object): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. - def accumulate(self, x, class_name: Optional[str] = None): """ Accumulate tensors. @@ -193,17 +193,12 @@ class TensorDiagnostic(object): done = True break if not done: - if ( - this_dim_stats[stats_type] != [] - and stats_type == "eigs" - ): + if this_dim_stats[stats_type] != [] and stats_type == "eigs": # >1 size encountered on this dim, e.g. it's a batch or time dimension, # don't accumulat "eigs" stats type, it uses too much memory this_dim_stats[stats_type] = None else: - this_dim_stats[stats_type].append( - TensorAndCount(stats, count) - ) + this_dim_stats[stats_type].append(TensorAndCount(stats, count)) def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" @@ -220,8 +215,11 @@ class TensorDiagnostic(object): for r, v in zip(rms_stats_list, value_stats_list): stddev_stats_list.append( # r.count and v.count should be the same, but we don't check this. 
- TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), - r.count)) + TensorAndCount( + r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), + r.count, + ) + ) this_dim_stats["stddev"] = stddev_stats_list for stats_type, stats_list in this_dim_stats.items(): @@ -232,7 +230,6 @@ class TensorDiagnostic(object): assert stats_type == "eigs" continue - def get_count(count): return 1 if stats_type in ["max", "min"] else count @@ -250,22 +247,20 @@ class TensorDiagnostic(object): eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." - ) + print("Error getting eigenvalues, trying another method.") eigs, _ = torch.eig(stats) stats = eigs.norm(dim=1).sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - if stats_type in [ "rms", "stddev" ]: + if stats_type in ["rms", "stddev"]: # we stored the square; after aggregation we need to take sqrt. stats = stats.sqrt() # if `summarize` we print percentiles of the stats; else, # we print out individual elements. - summarize = ( - len(stats_list) > 1 - ) or self.opts.dim_is_summarized(stats.numel()) + summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( + stats.numel() + ) if summarize: # usually `summarize` will be true # print out percentiles. stats = stats.sort()[0] @@ -282,15 +277,15 @@ class TensorDiagnostic(object): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type in [ "value", "rms", "stddev", "eigs" ]: + if stats_type in ["value", "rms", "stddev", "eigs"]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() + norm = (stats**2).sum().sqrt().item() ans += f", norm={norm:.2g}" mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() + rms = (stats**2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" # OK, "ans" contains the actual stats, e.g. @@ -298,11 +293,11 @@ class TensorDiagnostic(object): sizes = [x.tensor.shape[0] for x in stats_list] size_str = ( - f"{sizes[0]}" - if len(sizes) == 1 - else f"{min(sizes)}..{max(sizes)}" + f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" + ) + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" ) - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" print( f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) @@ -330,7 +325,6 @@ class ScalarDiagnostic(object): self.sum_gradsq = None self.sum_abs_grad = None - def accumulate_input(self, x: Tensor, class_name: Optional[str] = None): """ Called in forward pass. @@ -347,8 +341,10 @@ class ScalarDiagnostic(object): limit = 10 if len(self.saved_inputs) > limit: - print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. " - f" Will not accumulate scalar stats.") + print( + f"ERROR: forward pass called for this module over {limit} times with no backward pass. " + f" Will not accumulate scalar stats." 
+ ) self.is_ok = False return self.saved_inputs.append(x) @@ -359,11 +355,15 @@ class ScalarDiagnostic(object): if self.is_forward_pass: self.is_forward_pass = False - last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + last_shape = ( + "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + ) if len(self.saved_inputs) == 0 or grad.shape != last_shape: - print(f"ERROR: shape mismatch or no forward activation present when backward " - f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" - f", shape-of-last-saved-input={last_shape}") + print( + f"ERROR: shape mismatch or no forward activation present when backward " + f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" + f", shape-of-last-saved-input={last_shape}" + ) self.is_ok = False return @@ -384,11 +384,19 @@ class ScalarDiagnostic(object): self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1] - self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device) - self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.counts = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.long, device=x.device + ) + self.sum_grad = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) # sum_gradsq is for getting error bars. - self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) - self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.sum_gradsq = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) + self.sum_abs_grad = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) # this will round down. 
x = (x / self.tick_scale).to(torch.long) @@ -397,20 +405,21 @@ class ScalarDiagnostic(object): self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) - self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double)) + self.sum_gradsq.index_add_( + dim=0, index=x, source=(grad * grad).to(torch.double) + ) self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) - def print_diagnostics(self): """Print diagnostics.""" if self.is_ok is False or self.counts is None: print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") return - counts = self.counts.to('cpu') - sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32) - sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32) - sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32) + counts = self.counts.to("cpu") + sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32) + sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32) + sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32) counts_cumsum = counts.cumsum(dim=0) counts_tot = counts_cumsum[-1] @@ -433,19 +442,22 @@ class ScalarDiagnostic(object): bin_abs_grad = torch.zeros(num_bins) bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) - avg_grad = (bin_grad / bin_counts) + avg_grad = bin_grad / bin_counts avg_grad_stddev = (bin_gradsq / bin_counts).sqrt() - bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin + bin_boundary_counts = ( + torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin + ) bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) # boundaries are the "x" values between the bins, e.g. corresponding to the # locations of percentiles of the distribution. num_ticks_per_side = counts.numel() // 2 bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale - bin_grad = bin_grad / (bin_counts + 1) - bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1) # consider this a standard deviation. + bin_conf_interval = bin_gradsq.sqrt() / ( + bin_counts + 1 + ) # consider this a standard deviation. # bin_grad / bin_abs_grad will give us a sense for how important in a practical sense, # the gradients are. bin_abs_grad = bin_abs_grad / (bin_counts + 1) @@ -458,8 +470,9 @@ class ScalarDiagnostic(object): x = "[" + " ".join(x) + "]" return x - - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" + ) print( f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, " @@ -467,7 +480,6 @@ class ScalarDiagnostic(object): ) - class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -485,9 +497,8 @@ class ModelDiagnostic(object): self.opts = opts self.diagnostics = dict() - def __getitem__(self, name: str): - T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic + T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic if name not in self.diagnostics: self.diagnostics[name] = T(self.opts, name) return self.diagnostics[name] @@ -502,18 +513,19 @@ def get_class_name(module: nn.Module): ans = type(module).__name__ # we put the below in try blocks in case anyone is using a different version of these modules that # might have different member names. 
- if ans == 'Balancer' or ans == 'ActivationBalancer': + if ans == "Balancer" or ans == "ActivationBalancer": try: - ans += f'[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]' + ans += f"[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]" except: pass - elif ans == 'AbsValuePenalizer': + elif ans == "AbsValuePenalizer": try: - ans += f'[{module.limit}]' + ans += f"[{module.limit}]" except: pass return ans + def attach_diagnostics( model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None ) -> ModelDiagnostic: @@ -538,73 +550,85 @@ def attach_diagnostics( if name == "": name = "" - - # Setting model_diagnostic=ans and n=name below, instead of trying to # capture the variables, ensures that we use the current values. # (this matters for `name`, since the variable gets overwritten). # These closures don't really capture by value, only by # "the final value the variable got in the function" :-( - def forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output"].accumulate(_output, - class_name=get_class_name(_module)) + if isinstance(_output, Tensor) and _output.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.output"].accumulate( + _output, class_name=get_class_name(_module) + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=get_class_name(_module)) + if o.dtype in (torch.float32, torch.float16, torch.float64): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) - def backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad"].accumulate(_output, - class_name=get_class_name(_module)) + if isinstance(_output, Tensor) and _output.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.grad"].accumulate( + _output, class_name=get_class_name(_module) + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=get_class_name(_module)) - + if o.dtype in (torch.float32, torch.float16, torch.float64): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) - if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]: + if type(module).__name__ in [ + "Sigmoid", + "Tanh", + "ReLU", + "TanSwish", + "Swish", + "DoubleSwish", + "Swoosh", + ]: # For these specific module types, accumulate some additional diagnostics # that can help us improve the activation function. 
These require a lot of memory, # to save the forward activations, so limit this to some select classes. # Note: this will not work correctly for all model types. def scalar_forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name + _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_input, tuple): - _input, = _input + (_input,) = _input assert isinstance(_input, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input, - class_name=get_class_name(_module)) + _model_diagnostic[f"{_name}.scalar"].accumulate_input( + _input, class_name=get_class_name(_module) + ) def scalar_backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name + _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_output, tuple): - _output, = _output + (_output,) = _output assert isinstance(_output, Tensor) _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) module.register_forward_hook(scalar_forward_hook) module.register_backward_hook(scalar_backward_hook) - - for name, parameter in model.named_parameters(): def param_backward_hook( diff --git a/icefall/profiler.py b/icefall/profiler.py index dc76ebebc..49e138579 100644 --- a/icefall/profiler.py +++ b/icefall/profiler.py @@ -70,25 +70,17 @@ class FlopsProfiler(object): module_flop_count.append([]) if not hasattr(module, "__pre_hook_handle__"): - module.__pre_hook_handle__ = module.register_forward_pre_hook( - pre_hook - ) + module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook) def post_hook(module, input, output): if module_flop_count: - module.__flops__ += sum( - [elem[1] for elem in module_flop_count[-1]] - ) + module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]]) module_flop_count.pop() if not hasattr(module, "__post_hook_handle__"): - module.__post_hook_handle__ = module.register_forward_hook( - post_hook - ) + module.__post_hook_handle__ = module.register_forward_hook(post_hook) - self.model.apply( - partial(register_module_hooks, ignore_list=ignore_list) - ) + self.model.apply(partial(register_module_hooks, ignore_list=ignore_list)) self.started = True self.func_patched = True @@ -194,9 +186,7 @@ def _prelu_flops_compute(input: Tensor, weight: Tensor): return input.numel() -def _elu_flops_compute( - input: Tensor, alpha: float = 1.0, inplace: bool = False -): +def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False): return input.numel() @@ -259,9 +249,7 @@ def _conv_flops_compute( output_dims.append(output_dim) filters_per_channel = out_channels // groups - conv_per_position_macs = ( - int(_prod(kernel_dims)) * in_channels * filters_per_channel - ) + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(_prod(output_dims)) overall_conv_macs = conv_per_position_macs * active_elements_count overall_conv_flops = 2 * overall_conv_macs @@ -297,7 +285,6 @@ def _conv_trans_flops_compute( output_dims = [] for idx, input_dim in enumerate(input_dims): - output_dim = ( input_dim + 2 * paddings[idx] @@ -310,9 +297,7 @@ def _conv_trans_flops_compute( dilations = dilation if type(dilation) is tuple else (dilation, dilation) filters_per_channel = out_channels // groups - conv_per_position_macs = ( - int(_prod(kernel_dims)) * in_channels * filters_per_channel - ) + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(_prod(input_dims)) overall_conv_macs = 
conv_per_position_macs * active_elements_count overall_conv_flops = 2 * overall_conv_macs @@ -389,9 +374,7 @@ def _upsample_flops_compute(input, **kwargs): else: return int(size), 0 scale_factor = kwargs.get("scale_factor", None) - assert ( - scale_factor is not None - ), "either size or scale_factor should be defined" + assert scale_factor is not None, "either size or scale_factor should be defined" flops = input.numel() if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): flops * int(_prod(scale_factor)) @@ -593,12 +576,8 @@ def _patch_functionals(): F.embedding = wrapFunc(F.embedding, _embedding_flops_compute) # swoosh functions in k2 - k2.swoosh_l_forward = wrapFunc( - k2.swoosh_l_forward, _k2_swoosh_flops_compute - ) - k2.swoosh_r_forward = wrapFunc( - k2.swoosh_r_forward, _k2_swoosh_flops_compute - ) + k2.swoosh_l_forward = wrapFunc(k2.swoosh_l_forward, _k2_swoosh_flops_compute) + k2.swoosh_r_forward = wrapFunc(k2.swoosh_r_forward, _k2_swoosh_flops_compute) k2.swoosh_l = wrapFunc(k2.swoosh_l, _k2_swoosh_flops_compute) k2.swoosh_r = wrapFunc(k2.swoosh_r, _k2_swoosh_flops_compute) @@ -612,9 +591,7 @@ def _patch_tensor_methods(): torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute) torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute) - torch.Tensor.addmm = wrapFunc( - torch.Tensor.addmm, _tensor_addmm_flops_compute - ) + torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute) torch.mul = wrapFunc(torch.mul, _mul_flops_compute) torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute) @@ -631,14 +608,10 @@ def _patch_tensor_methods(): torch.tanh = wrapFunc(torch.tanh, _tanh_flops_compute) - torch.Tensor.softmax = wrapFunc( - torch.Tensor.softmax, _softmax_flops_compute - ) + torch.Tensor.softmax = wrapFunc(torch.Tensor.softmax, _softmax_flops_compute) torch.sigmoid = wrapFunc(torch.sigmoid, _sigmoid_flops_compute) - torch.Tensor.sigmoid = wrapFunc( - torch.Tensor.sigmoid, _sigmoid_flops_compute - ) + torch.Tensor.sigmoid = wrapFunc(torch.Tensor.sigmoid, _sigmoid_flops_compute) def _reload_functionals(): @@ -732,15 +705,11 @@ def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): flops += rnn_module.hidden_size * 4 # two hadamard _product and add for C state flops += ( - rnn_module.hidden_size - + rnn_module.hidden_size - + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size ) # final hadamard flops += ( - rnn_module.hidden_size - + rnn_module.hidden_size - + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size ) return flops diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py index d51a4b76b..28b908f82 100755 --- a/icefall/rnn_lm/check-onnx-streaming.py +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -112,7 +112,6 @@ def main(): for torch_v, onnx_v in zip( (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) ): - assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( torch_v.shape, onnx_v.shape, diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 3d206d139..0178b80bf 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -463,7 +463,6 @@ def train_one_epoch( cur_batch_idx = params.get("cur_batch_idx", 0) for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: continue cur_batch_idx = batch_idx diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index 7150297d6..231aca7f1 100755 --- 
a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -225,7 +225,6 @@ class NgramCounts: for n in range(0, self.ngram_order - 1): this_order_counts = self.counts[n] for hist, counts_for_hist in this_order_counts.items(): - n_star_star = 0 for w in counts_for_hist.word_to_count.keys(): n_star_star += len(counts_for_hist.word_to_context[w]) @@ -424,7 +423,6 @@ class NgramCounts: if __name__ == "__main__": - ngram_counts = NgramCounts(args.ngram_order) if args.text is None: diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py index 79dda3168..c78cf1821 100644 --- a/icefall/transformer_lm/model.py +++ b/icefall/transformer_lm/model.py @@ -103,7 +103,6 @@ class TransformerLM(torch.nn.Module): return nll_loss def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): - bs = x.size(0) state = None diff --git a/requirements-ci.txt b/requirements-ci.txt index 2433e190b..652e2ab47 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -20,6 +20,7 @@ kaldialign==0.7.1 sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 +black==22.3.0 multi_quantization onnx diff --git a/requirements.txt b/requirements.txt index a07f6b7c7..f0098c236 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ sentencepiece>=0.1.96 tensorboard typeguard dill +black==22.3.0 From 97f9b9c33b9e3d4a7152c45f28dec397202aabb6 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:48:50 +0800 Subject: [PATCH 044/113] Add documentation for RNNLM training (#1267) * add documentation for training an RNNLM --- .../decoding-with-langugage-models/index.rst | 5 +- docs/source/recipes/RNN-LM/index.rst | 7 ++ .../RNN-LM/librispeech/lm-training.rst | 104 ++++++++++++++++++ docs/source/recipes/index.rst | 1 + 4 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 docs/source/recipes/RNN-LM/index.rst create mode 100644 docs/source/recipes/RNN-LM/librispeech/lm-training.rst diff --git a/docs/source/decoding-with-langugage-models/index.rst b/docs/source/decoding-with-langugage-models/index.rst index 6e5e3a4d9..c49da9a4e 100644 --- a/docs/source/decoding-with-langugage-models/index.rst +++ b/docs/source/decoding-with-langugage-models/index.rst @@ -2,12 +2,13 @@ Decoding with language models ============================= This section describes how to use external langugage models -during decoding to improve the WER of transducer models. +during decoding to improve the WER of transducer models. To train an external language model, +please refer to this tutorial: :ref:`train_nnlm`. The following decoding methods with external langugage models are available: -.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean) +.. list-table:: :widths: 25 50 :header-rows: 1 diff --git a/docs/source/recipes/RNN-LM/index.rst b/docs/source/recipes/RNN-LM/index.rst new file mode 100644 index 000000000..4b74e64c7 --- /dev/null +++ b/docs/source/recipes/RNN-LM/index.rst @@ -0,0 +1,7 @@ +RNN-LM +====== + +.. toctree:: + :maxdepth: 2 + + librispeech/lm-training \ No newline at end of file diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst new file mode 100644 index 000000000..736120275 --- /dev/null +++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst @@ -0,0 +1,104 @@ +.. 
_train_nnlm:
+
+Train an RNN language model
+======================================
+
+If you have enough text data, you can train a neural network language model (NNLM) to improve
+the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from
+scratch.
+
+.. HINT::
+
+    For how to use an NNLM during decoding, please refer to the following tutorials:
+    :ref:`shallow_fusion`, :ref:`LODR`, :ref:`rescoring`
+
+.. note::
+
+    This tutorial is based on the LibriSpeech recipe. Please check it out for the necessary
+    Python scripts used in this tutorial. We use the LibriSpeech LM corpus as the LM training set
+    for illustration purposes. You can also collect your own data. The data format is quite simple:
+    each line should contain a complete sentence, and words should be separated by spaces.
+
+First, let's download the training data for the RNNLM. This can be done via the
+following command:
+
+.. code-block:: bash
+
+    $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+    $ gzip -d librispeech-lm-norm.txt.gz
+
+As we are training a BPE-level RNNLM, we need to tokenize the training text, which requires a
+BPE tokenizer. This can be achieved by executing the following command:
+
+.. code-block:: bash
+
+    $ # if you don't have the BPE model
+    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+    $ cd icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500
+    $ git lfs pull --include bpe.model
+    $ cd ../../..
+
+    $ ./local/prepare_lm_training_data.py \
+        --bpe-model icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/bpe.model \
+        --lm-data librispeech-lm-norm.txt \
+        --lm-archive data/lang_bpe_500/lm_data.pt
+
+Now you should have a file named ``lm_data.pt`` stored under the directory ``data/lang_bpe_500``.
+This is the packed training data for the RNNLM. We then sort the training data according to its
+sentence length.
+
+.. code-block:: bash
+
+    $ # This could take a while (~ 20 minutes), feel free to grab a cup of coffee :)
+    $ ./local/sort_lm_training_data.py \
+        --in-lm-data data/lang_bpe_500/lm_data.pt \
+        --out-lm-data data/lang_bpe_500/sorted_lm_data.pt \
+        --out-statistics data/lang_bpe_500/lm_data_stats.txt
+
+
+The aforementioned steps can be repeated to create a validation set for your RNNLM. If your
+validation text is in ``valid.txt``, simply set ``--lm-data valid.txt``
+and ``--lm-archive data/lang_bpe_500/lm-data-valid.pt`` when calling ``./local/prepare_lm_training_data.py``
+(a sketch of these commands is given at the end of this tutorial).
+
+After completing the previous steps, the training and validation sets for the RNNLM are ready.
+The next step is to train the RNNLM. The training command is as follows:
+
+.. code-block:: bash
+
+    $ # assume you are in the icefall root directory
+    $ cd rnn_lm
+    $ ln -s ../../egs/librispeech/ASR/data .
+    $ cd ..
+    $ ./rnn_lm/train.py \
+        --world-size 4 \
+        --exp-dir ./rnn_lm/exp \
+        --start-epoch 0 \
+        --num-epochs 10 \
+        --use-fp16 0 \
+        --tie-weights 1 \
+        --embedding-dim 2048 \
+        --hidden_dim 2048 \
+        --num-layers 3 \
+        --batch-size 300 \
+        --lm-data rnn_lm/data/lang_bpe_500/sorted_lm_data.pt \
+        --lm-data-valid rnn_lm/data/lang_bpe_500/sorted_lm_data.pt
+
+
+.. note::
+
+    You can adjust the RNNLM hyperparameters, such as the embedding dimension and the hidden
+    state dimension, to control the size of the RNNLM. For more details, please
+    run ``./rnn_lm/train.py --help``.
+
+.. note::
+
+    Training the RNNLM can take a long time (usually a couple of days).
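+
+The following is a minimal, illustrative sketch of the validation-data commands referenced
+above. It is not taken from the recipe verbatim: ``valid.txt`` stands for your own held-out
+text, and the output file names are placeholders that you may change, as long as they match
+the ``--lm-data-valid`` option you pass to ``./rnn_lm/train.py``.
+
+.. code-block:: bash
+
+    $ # Illustrative file names; adjust them to your own setup.
+    $ ./local/prepare_lm_training_data.py \
+        --bpe-model icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/bpe.model \
+        --lm-data valid.txt \
+        --lm-archive data/lang_bpe_500/lm-data-valid.pt
+
+    $ ./local/sort_lm_training_data.py \
+        --in-lm-data data/lang_bpe_500/lm-data-valid.pt \
+        --out-lm-data data/lang_bpe_500/sorted_lm_data-valid.pt \
+        --out-statistics data/lang_bpe_500/lm_data_stats-valid.txt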
+ + + + + + + + + diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 63793275c..7265e1cf6 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -15,3 +15,4 @@ We may add recipes for other tasks as well in the future. Non-streaming-ASR/index Streaming-ASR/index + RNN-LM/index From e17f884ace2dba7561d4d4eaaac6726234cad20f Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 25 Sep 2023 15:36:40 +0800 Subject: [PATCH 045/113] Fix docs for MVQ (#1272) * typo fix --- .../librispeech/distillation.rst | 16 ++++++++-------- egs/librispeech/ASR/distillation_with_hubert.sh | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst index 2e8d0893a..37edf7de9 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -47,7 +47,7 @@ The data preparation contains several stages, you can use the following two options: - ``--stage`` - - ``--stop-stage`` + - ``--stop_stage`` to control which stage(s) should be run. By default, all stages are executed. @@ -56,8 +56,8 @@ For example, .. code-block:: bash $ cd egs/librispeech/ASR - $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0 - $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5 + $ ./prepare.sh --stage 0 --stop_stage 0 # run only stage 0 + $ ./prepare.sh --stage 2 --stop_stage 5 # run from stage 2 to stage 5 .. HINT:: @@ -108,15 +108,15 @@ As usual, you can control the stages you want to run by specifying the following two options: - ``--stage`` - - ``--stop-stage`` + - ``--stop_stage`` For example, .. code-block:: bash $ cd egs/librispeech/ASR - $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0 - $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5 + $ ./distillation_with_hubert.sh --stage 0 --stop_stage 0 # run only stage 0 + $ ./distillation_with_hubert.sh --stage 2 --stop_stage 4 # run from stage 2 to stage 5 Here are a few options in `./distillation_with_hubert.sh `_ you need to know before you proceed. @@ -134,7 +134,7 @@ and prepares MVQ-augmented training manifests. .. code-block:: bash - $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2 + $ ./distillation_with_hubert.sh --stage 2 --stop_stage 2 # run only stage 2 Please see the following screenshot for the output of an example execution. @@ -172,7 +172,7 @@ To perform training, please run stage 3 by executing the following command. .. code-block:: bash - $ ./prepare.sh --stage 3 --stop-stage 3 # run MVQ training + $ ./prepare.sh --stage 3 --stop_stage 3 # run MVQ training Here is the code snippet for training: diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index 6aaa0333b..a5b0b85af 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -56,6 +56,8 @@ use_extracted_codebook=True # "hubert_xtralarge_ll60k" -> pretrained model without fintuing teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960 +. 
shared/parse_options.sh || exit 1 + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} From 1b565dd25198f700bcfe88e86a0f6a435e11a429 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 26 Sep 2023 15:41:39 +0800 Subject: [PATCH 046/113] added softlinks to local dir (#1273) --- egs/tedlium3/ASR/conformer_ctc2/local | 1 + egs/tedlium3/ASR/pruned_transducer_stateless/local | 1 + egs/tedlium3/ASR/transducer_stateless/local | 1 + egs/tedlium3/ASR/zipformer/local | 1 + 4 files changed, 4 insertions(+) create mode 120000 egs/tedlium3/ASR/conformer_ctc2/local create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/local create mode 120000 egs/tedlium3/ASR/transducer_stateless/local create mode 120000 egs/tedlium3/ASR/zipformer/local diff --git a/egs/tedlium3/ASR/conformer_ctc2/local b/egs/tedlium3/ASR/conformer_ctc2/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/local b/egs/tedlium3/ASR/pruned_transducer_stateless/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/local b/egs/tedlium3/ASR/transducer_stateless/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/transducer_stateless/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/local b/egs/tedlium3/ASR/zipformer/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/zipformer/local @@ -0,0 +1 @@ +../local \ No newline at end of file From 2318c3fbd011b14ceffe8b3a8663057708afeea0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Sep 2023 16:36:19 +0800 Subject: [PATCH 047/113] Support CTC decoding on CPU using OpenFst and kaldi decoders. 
(#1244) --- .flake8 | 1 + .../scripts/run-pre-trained-conformer-ctc.sh | 43 +++ .../run-pretrained-conformer-ctc.yml | 2 +- .github/workflows/run-yesno-recipe.yml | 37 ++ .gitignore | 2 + docs/source/model-export/export-ncnn.rst | 2 + .../jit_pretrained_decode_with_H.py | 235 ++++++++++++ .../jit_pretrained_decode_with_HL.py | 232 ++++++++++++ egs/librispeech/ASR/local/prepare_lang_fst.py | 127 +++++++ .../lstm_transducer_stateless/test_model.py | 3 +- egs/librispeech/ASR/prepare.sh | 4 + egs/yesno/ASR/local/prepare_lang_fst.py | 1 + egs/yesno/ASR/prepare.sh | 1 + egs/yesno/ASR/tdnn/jit_pretrained.py | 1 - .../ASR/tdnn/jit_pretrained_decode_with_H.py | 208 +++++++++++ .../ASR/tdnn/jit_pretrained_decode_with_HL.py | 207 +++++++++++ icefall/ctc/.gitignore | 2 + icefall/ctc/README.md | 17 + icefall/ctc/__init__.py | 6 + icefall/ctc/prepare_lang.py | 334 ++++++++++++++++++ icefall/ctc/test_ctc_topo.py | 140 ++++++++ icefall/ctc/test_prepare_lang.py | 43 +++ icefall/ctc/topo.py | 137 +++++++ requirements-ci.txt | 1 + requirements.txt | 1 + 25 files changed, 1783 insertions(+), 4 deletions(-) create mode 100755 egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py create mode 100755 egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py create mode 100755 egs/librispeech/ASR/local/prepare_lang_fst.py create mode 120000 egs/yesno/ASR/local/prepare_lang_fst.py create mode 100755 egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py create mode 100755 egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py create mode 100644 icefall/ctc/.gitignore create mode 100644 icefall/ctc/README.md create mode 100644 icefall/ctc/__init__.py create mode 100644 icefall/ctc/prepare_lang.py create mode 100755 icefall/ctc/test_ctc_topo.py create mode 100755 icefall/ctc/test_prepare_lang.py create mode 100644 icefall/ctc/topo.py diff --git a/.flake8 b/.flake8 index 1c0c2cdbb..410cb5482 100644 --- a/.flake8 +++ b/.flake8 @@ -24,6 +24,7 @@ exclude = **/data/**, icefall/shared/make_kn_lm.py, icefall/__init__.py + icefall/ctc/__init__.py ignore = # E203 white space before ":" diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index a4959aa01..19cbd96fc 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -44,3 +44,46 @@ log "HLG decoding" $repo/test_wavs/1089-134686-0001.flac \ $repo/test_wavs/1221-135766-0001.flac \ $repo/test_wavs/1221-135766-0002.flac + +log "CTC decoding on CPU with kaldi decoders using OpenFst" + +log "Exporting model with torchscript" + +pushd $repo/exp +ln -s pretrained.pt epoch-99.pt +popd + +./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit 1 + +ls -lh $repo/exp + + +log "Generating H.fst, HL.fst" + +./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 +ls -lh $repo/data/lang_bpe_500 + +log "Decoding with H on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_bpe_500/H.fst \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.flac \ + $repo/test_wavs/1221-135766-0001.flac \ + $repo/test_wavs/1221-135766-0002.flac + +log "Decoding with HL on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_bpe_500/HL.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + 
$repo/test_wavs/1089-134686-0001.flac \ + $repo/test_wavs/1221-135766-0001.flac \ + $repo/test_wavs/1221-135766-0002.flac diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 6151a5a14..e268d840d 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -29,7 +29,7 @@ concurrency: jobs: run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 57f15fe87..400595749 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -140,9 +140,46 @@ jobs: download/waves_yesno/0_0_0_1_0_0_0_1.wav \ download/waves_yesno/0_0_1_0_0_0_1_0.wav + - name: Test decoding with H + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + + - name: Test decoding with HL + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + - name: Show generated files shell: bash working-directory: ${{github.workspace}} run: | cd egs/yesno/ASR ls -lh tdnn/exp + ls -lh data/lang_phone diff --git a/.gitignore b/.gitignore index 8af05d884..fa18ca83c 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ node_modules *.param *.bin .DS_Store +*.fst +*.arpa diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 9eb5f85d2..634fb1e59 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,3 +1,5 @@ +.. _icefall_export_to_ncnn: + Export to ncnn ============== diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py new file mode 100755 index 000000000..b52c7cfed --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. 
+ +Usage: + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_bpe_500/H.fst \ + --tokens ./data/lang_bpe_500/tokens.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldi_hmm_gmm +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument("--H", type=str, required=True, help="Path to H.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_tokens(tokens_txt: str) -> Dict[int, str]: + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token, idx = line.strip().split() + id2token[int(idx)] = token + + return id2token + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2token: + A map mapping token ID to token string. + Returns: + Return a list of decoded tokens. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # so they need to be decremented + hyps = [id2token[i - 1] for i in osymbols_out] + # hyps = "".join(hyps).split("▁") + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2token = read_tokens(args.tokens) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + H=H, + id2token=id2token, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py new file mode 100755 index 000000000..f0326ccdf --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. 
+ +Usage: + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_bpe_500/HL.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldi_hmm_gmm +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + word2token: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py new file mode 100755 index 000000000..e8401123f --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as input lang_dir containing lexicon_disambig.txt, +tokens.txt, and words.txt and generates the following files: + + - H.fst + - HL.fst + +Note that saved files are in OpenFst binary format. 
+ +Usage: + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_phone \ + --has-silence 1 + +Or + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_bpe_500 +""" + +import argparse +import logging +from pathlib import Path + +import kaldifst + +from icefall.ctc import ( + Lexicon, + add_disambig_self_loops, + add_one, + build_standard_ctc_topo, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + parser.add_argument( + "--has-silence", + type=str2bool, + default=False, + help="True if the lexicon has silence.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = args.lang_dir + + lexicon = Lexicon(lang_dir) + + logging.info("Building standard CTC topology") + max_token_id = max(lexicon.tokens) + H = build_standard_ctc_topo(max_token_id=max_token_id) + + # We need to add one to all tokens since we want to use ID 0 + # for epsilon + add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) + H.write(f"{lang_dir}/H.fst") + + logging.info("Building L") + # Now for HL + + if args.has_silence: + L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False) + else: + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + + if args.has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + with open("H_1.fst.txt", "w") as f: + print(H, file=f) + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + + logging.info("Building HL") + HL = kaldifst.compose(H, L) + kaldifst.determinize_star(HL) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HL): + for arc in kaldifst.ArcIterator(HL, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + + # Note: We are not composing L with G, so there is no need to add + # self-loops to L to handle #0 + + HL.write(f"{lang_dir}/HL.fst") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py index 03dfe1997..91ef53e24 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py @@ -57,8 +57,7 @@ def test_model(): convert_scaled_to_non_scaled(model, inplace=True) - if not os.path.exists(params.exp_dir): - os.path.mkdir(params.exp_dir) + params.exp_dir.mkdir(exist_ok=True) encoder_filename = params.exp_dir / "encoder_jit_trace.pt" export_encoder_model_jit_trace(model.encoder, encoder_filename) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 8ce1eb478..fca2c6cc4 100755 --- a/egs/librispeech/ASR/prepare.sh +++ 
b/egs/librispeech/ASR/prepare.sh @@ -242,6 +242,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then $lang_dir/L_disambig.pt \ $lang_dir/L_disambig.fst fi + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py --lang-dir $lang_dir + fi done fi diff --git a/egs/yesno/ASR/local/prepare_lang_fst.py b/egs/yesno/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/yesno/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh index d4ef8d601..41db0cf7c 100755 --- a/egs/yesno/ASR/prepare.sh +++ b/egs/yesno/ASR/prepare.sh @@ -60,6 +60,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then ) > $lang_dir/lexicon.txt ./local/prepare_lang.py + ./local/prepare_lang_fst.py --lang-dir ./data/lang_phone --has-silence 1 fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 84390fca5..7581ecb83 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -156,7 +156,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - # Note: We don't use key padding mask for attention during decoding nnet_output = model(features) batch_size = nnet_output.shape[0] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py new file mode 100755 index 000000000..209ab477a --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. + +Usage: + + ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument("--H", type=str, required=True, help="Path to H.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
", + ) + + return parser + + +def read_tokens(tokens_txt: str) -> Dict[int, str]: + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token, idx = line.strip().split() + id2token[int(idx)] = token + + return id2token + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + sample_rate = 8000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 23 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output = model(features) + + id2token = read_tokens(args.tokens) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[0], + nnet_output=nnet_output[i], + H=H, + id2token=id2token, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py new file mode 100755 index 000000000..74864e17d --- /dev/null +++ 
b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HL +on CPU using OpenFST and decoders from kaldi. + +Usage: + + ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + hyps = [id2word[i] for i in osymbols_out if id2word[i] != ""] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 8000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 23 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output = model(features) + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[0], + nnet_output=nnet_output[i], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/ctc/.gitignore b/icefall/ctc/.gitignore new file mode 100644 index 000000000..8154cb57f --- /dev/null +++ b/icefall/ctc/.gitignore @@ -0,0 +1,2 @@ +*.pdf +*.gv diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md new file mode 100644 index 000000000..07b0ff8cd --- /dev/null +++ b/icefall/ctc/README.md @@ -0,0 +1,17 @@ +# Introduction + +This folder uses [kaldifst][kaldifst] for graph construction +and decoders from [kaldi-hmm-gmm][kaldi-hmm-gmm] for CTC decoding. + +It supports only `CPU`. + +You can use + +```bash +pip install kaldifst kaldi-hmm-gmm +``` +to install the dependencies. 
+ +[kaldi-hmm-gmm]: https://github.com/csukuangfj/kaldi-hmm-gmm +[kaldifst]: https://github.com/k2-fsa/kaldifst +[k2]: https://github.com/k2-fsa/k2 diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py new file mode 100644 index 000000000..b546b31af --- /dev/null +++ b/icefall/ctc/__init__.py @@ -0,0 +1,6 @@ +from .prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo diff --git a/icefall/ctc/prepare_lang.py b/icefall/ctc/prepare_lang.py new file mode 100644 index 000000000..4801b1beb --- /dev/null +++ b/icefall/ctc/prepare_lang.py @@ -0,0 +1,334 @@ +# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang) + +""" +The lang_dir should contain the following files: + - "lexicon_disambig.txt" + - "tokens.txt" + - "words.txt" +""" + +import math +from collections import defaultdict +from pathlib import Path +from typing import List, Tuple + +import kaldifst +import re + + +class Lexicon: + """Once constructed it is immutable""" + + def __init__( + self, + lang_dir: str, + disambig_pattern: str = re.compile(r"^#\d+$"), + ): + """ + Args: + lang_dir: + The path to the lang directory. We expect that it contains the + following files: + - lexicon_disambig.txt + - tokens.txt + - words.txt + + The format of the above files is described below. + + (1) lexicon_disambig.txt + + Each line in the lexicon_disambig.txt has the following format: + + word token1 token2 ... tokenN + + That is, the first field is the word, the remaining fields are + pronunciations of this word. Fields are separated by space(s). + + (2) tokens.txt + + Each line in tokens.txt has two fields separated by space(s): + + token ID + + The first field is the token symbol and the second filed is the + integer ID of the token. + + (3) words.txt + + Each line in words.txt has two fields separated by space(s): + + word ID + + The first field is the word symbol and the second filed is the + integer ID of the word. + disambig_pattern: + It contains the pattern for disambiguation symbols. 
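+
+        For illustration only (the entries below are made up and are not
+        taken from any recipe), the three files could contain lines such as:
+
+            lexicon_disambig.txt:   HELLO HE LLO #1
+            tokens.txt:             LLO 3
+            words.txt:              HELLO 27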
+ """ + lang_dir = Path(lang_dir) + + lexicon_txt = lang_dir / "lexicon_disambig.txt" + tokens_txt = lang_dir / "tokens.txt" + words_txt = lang_dir / "words.txt" + + assert lexicon_txt.is_file(), lexicon_txt + assert tokens_txt.is_file(), tokens_txt + assert words_txt.is_file(), words_txt + + self._read_lexicon(lexicon_txt) + self._read_tokens(tokens_txt) + self._read_words(words_txt) + + self.disambig_pattern = disambig_pattern + + max_disambig_id = -1 + for s, i in self.token2id.items(): + if self.disambig_pattern.match(s) and i > max_disambig_id: + max_disambig_id = i + + self.max_disambig_id = max_disambig_id + + def _read_lexicon(self, lexicon_txt: str): + word2phones = defaultdict(list) + with open(lexicon_txt, encoding="utf-8") as f: + for line in f: + word_phones = line.strip().split() + assert len(word_phones) >= 2, (word_phones, line) + word = word_phones[0] + phones: str = " ".join(word_phones[1:]) + word2phones[word].append(phones) + # We use a list here since a word may have multiple + # pronunciations + + self.word2phones = word2phones + + def _read_tokens(self, tokens_txt): + token2id = dict() + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token_id = line.strip().split() + assert len(token_id) == 2, token_id + + token = token_id[0] + idx = int(token_id[1]) + + assert token not in token2id, f"Duplicate token {line}" + assert idx not in id2token, f"Duplicate ID {line}" + + token2id[token] = idx + id2token[idx] = token + self.token2id = token2id + self.id2token = id2token + + def _read_words(self, words_txt): + word2id = dict() + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word_id = line.strip().split() + assert len(word_id) == 2, word_id + + word = word_id[0] + idx = int(word_id[1]) + + assert word not in word2id, f"Duplicate token {line}" + assert idx not in id2word, f"Duplicate ID {line}" + + word2id[word] = idx + id2word[idx] = word + + self.word2id = word2id + self.id2word = id2word + + def __iter__(self) -> Tuple[str, List[str]]: + for word, phones_list in self.word2phones.items(): + for phones in phones_list: + yield word, phones + + def __str__(self): + return str(self.word2phones) + + @property + def tokens(self) -> List[int]: + """Return a list of token IDs excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. 
+ """ + ans = [] + for s in self.token2id: + if not self.disambig_pattern.match(s): + ans.append(self.token2id[s]) + if 0 in ans: + ans.remove(0) + ans.sort() + return ans + + +# See also +# http://vpanayotov.blogspot.com/2012/06/kaldi-decoding-graph-construction.html +def make_lexicon_fst_with_silence( + lexicon: Lexicon, + sil_prob: float = 0.5, + sil_phone: str = "SIL", + attach_symbol_table: bool = True, +) -> kaldifst.StdVectorFst: + phone2id = lexicon.token2id + word2id = lexicon.word2id + + assert sil_phone in phone2id + + assert sil_phone in phone2id, sil_phone + + sil_cost = -1 * math.log(sil_prob) + no_sil_cost = -1 * math.log(1.0 - sil_prob) + + fst = kaldifst.StdVectorFst() + + start_state = fst.add_state() + loop_state = fst.add_state() + sil_state = fst.add_state() + + fst.start = start_state + fst.set_final(state=loop_state, weight=0) + + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=0, + weight=no_sil_cost, + nextstate=loop_state, + ), + ) + + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=0, + weight=sil_cost, + nextstate=sil_state, + ), + ) + + fst.add_arc( + state=sil_state, + arc=kaldifst.StdArc( + ilabel=phone2id[sil_phone], + olabel=0, + weight=0, + nextstate=loop_state, + ), + ) + + for word, phones in lexicon: + phoneseq = phones.split() + pron_cost = 0 + cur_state = loop_state + + for i in range(len(phoneseq) - 1): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]], + olabel=word2id[word] if i == 0 else 0, + weight=pron_cost if i == 0 else 0, + nextstate=next_state, + ), + ) + cur_state = next_state + + i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty. + + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=no_sil_cost + (pron_cost if i <= 0 else 0), + nextstate=loop_state, + ), + ) + + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=sil_cost + (pron_cost if i <= 0 else 0), + nextstate=sil_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in phone2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + return fst + + +def make_lexicon_fst_no_silence( + lexicon: Lexicon, + attach_symbol_table: bool = True, +) -> kaldifst.StdVectorFst: + phone2id = lexicon.token2id + word2id = lexicon.word2id + + fst = kaldifst.StdVectorFst() + + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + + for word, phones in lexicon: + phoneseq = phones.split() + pron_cost = 0 + cur_state = start_state + + for i in range(len(phoneseq) - 1): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]], + olabel=word2id[word] if i == 0 else 0, + weight=pron_cost if i == 0 else 0, + nextstate=next_state, + ), + ) + cur_state = next_state + + i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty. 
+ + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=pron_cost if i <= 0 else 0, + nextstate=start_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in phone2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + return fst diff --git a/icefall/ctc/test_ctc_topo.py b/icefall/ctc/test_ctc_topo.py new file mode 100755 index 000000000..4d4667209 --- /dev/null +++ b/icefall/ctc/test_ctc_topo.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +from pathlib import Path + +import graphviz +import kaldifst +import sentencepiece as spm +from prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo + + +def test_yesno(): + lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone" + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! Skip testing") + return + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + + H = build_standard_ctc_topo(max_token_id=max_token_id) + + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + for i in range(1, max_token_id + 1): + isym.add_symbol(symbol=lexicon.id2token[i], key=i) + + osym = kaldifst.SymbolTable() + osym.add_symbol(symbol="", key=0) + for i in range(1, max_token_id + 1): + osym.add_symbol(symbol=lexicon.id2token[i], key=i) + + H.input_symbols = isym + H.output_symbols = osym + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="standard_ctc_topo_yesno.pdf") + # See the link below to visualize the above PDF + # https://t.ly/7uXZ9 + + # Now test HL + + # We need to add one to all tokens since we want to use ID 0 + # for epsilon + add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) + + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id, + ) + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="standard_ctc_topo_disambig_yesno.pdf") + + L = make_lexicon_fst_with_silence(lexicon) + + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + + H.output_symbols = None + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + HL = kaldifst.compose(H, L) + + lexicon.id2token[0] = "" + lexicon.token2id[""] = 0 + + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + for i in range(0, lexicon.max_disambig_id + 1): + isym.add_symbol(symbol=lexicon.id2token[i], key=i + 1) + + osym = kaldifst.SymbolTable() + for i, word in lexicon.id2word.items(): + osym.add_symbol(symbol=word, key=i) + + HL.input_symbols = isym + HL.output_symbols = osym + + fst_dot = kaldifst.draw(HL, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="HL_yesno.pdf") + + +def test_librispeech(): + lang_dir = ( + "/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/data/lang_bpe_500" + ) + + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! 
Skip testing") + return + + lexicon = Lexicon(lang_dir) + HL = kaldifst.StdVectorFst.read(lang_dir + "/HL.fst") + + sp = spm.SentencePieceProcessor() + sp.load(lang_dir + "/bpe.model") + + i = lexicon.word2id["HELLOA"] + k = lexicon.word2id["WORLD"] + print(i, k) + s = f""" + 0 1 {i} {i} + 1 2 {k} {k} + 2 + """ + fst = kaldifst.compile( + s=s, + acceptor=False, + ) + + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + kaldifst.arcsort(L, sort_type="olabel") + with open("L.fst.txt", "w") as f: + print(L, file=f) + + fst = kaldifst.compose(L, fst) + print(fst) + fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="a.pdf") + print(sp.encode(["HELLOA", "WORLD"])) + + +def main(): + test_yesno() + test_librispeech() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/test_prepare_lang.py b/icefall/ctc/test_prepare_lang.py new file mode 100755 index 000000000..6c4b9e510 --- /dev/null +++ b/icefall/ctc/test_prepare_lang.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +from pathlib import Path + +import graphviz +import kaldifst +from prepare_lang import Lexicon, make_lexicon_fst_with_silence + + +def test_yesno(): + lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone" + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! Skip testing") + return + + lexicon = Lexicon(lang_dir) + + L = make_lexicon_fst_with_silence(lexicon) + + isym = kaldifst.SymbolTable() + for i, token in lexicon.id2token.items(): + isym.add_symbol(symbol=token, key=i) + + osym = kaldifst.SymbolTable() + for i, word in lexicon.id2word.items(): + osym.add_symbol(symbol=word, key=i) + + L.input_symbols = isym + L.output_symbols = osym + fst_dot = kaldifst.draw(L, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="L_yesno.pdf") + # See the link below to visualize the above PDF + # https://t.ly/jMfXW + + +def main(): + test_yesno() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/topo.py b/icefall/ctc/topo.py new file mode 100644 index 000000000..6a96dd038 --- /dev/null +++ b/icefall/ctc/topo.py @@ -0,0 +1,137 @@ +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +import kaldifst + + +# Note the name contains `standard`; it means there will be non-standard +# topologies. +def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst: + """Build a standard CTC topology. + + Args: + Maximum valid token ID. We assume token IDs are contiguous + and starts from 0. In other words, the vocabulary size is + ``max_token_id + 1``. We assume the ID of the blank symbol is 0. + """ + # Token ID starts from 0 and there are as many states as the + # number of tokens. + # + # Note that epsilon is not a token and the token with ID 0 in tokens.txt + # is not an epsilon. It means input label 0 of the resulting FST does + # not represent an epsilon. + # + # You can use the function `add_one()` to modify the input/output labels + # of the resulting FST + + num_states = max_token_id + 1 + + # Step 1: Create as many states as the number of tokens. + # Each state is a final state + fst = kaldifst.StdVectorFst() + for i in range(num_states): + s = fst.add_state() + fst.set_final(state=s, weight=0) + + # Step 2: Set state 0 as the start state. + # We assume the ID of the blank symbol is 0. + fst.start = 0 + + # Step 3: Build a fully connected graph. 
+ for i in range(num_states): + for k in range(num_states): + fst.add_arc( + state=i, + arc=kaldifst.StdArc( + ilabel=k, + olabel=k if i != k else 0, # if i==k, it is a self loop + weight=0, + nextstate=k, + ), + ) + # Please see ./test_ctc_topo.py if you want to know what the resulting + # FST looks like + + return fst + + +def add_one( + fst: kaldifst.StdVectorFst, + treat_ilabel_zero_specially: bool, + update_olabel: bool, +) -> None: + """Modify the input and output labels of the given FST in-place. + + Args: + fst: + The FST to be modified. It is changed in-place. + treat_ilabel_zero_specially: + If True, then every non-zero input label is increased by one and the + zero input label is not changed. + If False, then every input label is increased by one. + update_olabel: + If False, the output label is not changed. + If True, then every non-zero output label is increased by one. + In either case, output label with 0 is not changed. + """ + for state in kaldifst.StateIterator(fst): + for arc in kaldifst.ArcIterator(fst, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if treat_ilabel_zero_specially is False or arc.ilabel != 0: + arc.ilabel += 1 + + if update_olabel and arc.olabel != 0: + arc.olabel += 1 + + if fst.input_symbols is not None: + input_symbols = kaldifst.SymbolTable() + input_symbols.add_symbol(symbol="", key=0) + + for i in range(0, fst.input_symbols.num_symbols()): + s = fst.input_symbols.find(i) + input_symbols.add_symbol(symbol=s, key=i + 1) + + fst.input_symbols = input_symbols + + if update_olabel and fst.output_symbols is not None: + output_symbols = kaldifst.SymbolTable() + output_symbols.add_symbol(symbol="", key=0) + + for i in range(0, fst.output_symbols.num_symbols()): + s = fst.output_symbols.find(i) + output_symbols.add_symbol(symbol=s, key=i + 1) + + fst.output_symbols = output_symbols + + +def add_disambig_self_loops(fst: kaldifst.StdVectorFst, start: int, end: int): + """Add self-loops to each state. + + For each disambig symbol, we add a self-loop with input label disambig_id + and output label diambig_id of that disambig symbol. + + Args: + fst: + It is changed in-place. + start: + The ID of #0 + end: + The ID of the last disambig symbol. For instance if there are 3 + disambig symbols ``#0``, ``#1``, and ``#2``, then ``end`` is the ID + of ``#2``. 
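+
+    For example, prepare_lang_fst.py in the librispeech recipe calls it
+    (after ``add_one()`` has shifted all token IDs by one) as:
+
+        add_disambig_self_loops(
+            H,
+            start=lexicon.token2id["#0"] + 1,
+            end=lexicon.max_disambig_id + 1,
+        )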
+ """ + for state in kaldifst.StateIterator(fst): + for i in range(start, end + 1): + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=i, + weight=0, + nextstate=state, + ), + ) + + if fst.output_symbols: + for i in range(start, end + 1): + fst.output_symbols.add_symbol(symbol=f"#{i-start}", key=i) diff --git a/requirements-ci.txt b/requirements-ci.txt index 652e2ab47..6f8739ce0 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -27,3 +27,4 @@ onnx onnxmltools onnxruntime kaldifst +kaldi-hmm-gmm diff --git a/requirements.txt b/requirements.txt index f0098c236..c031d683c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ kaldifst kaldilm kaldialign +kaldi-hmm-gmm sentencepiece>=0.1.96 tensorboard typeguard From 772ee3955bcfcfcaf76c06aeb40f69765609f7b4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 27 Sep 2023 14:49:27 +0800 Subject: [PATCH 048/113] Support HLG decoding using OpenFst with kaldi decoders (#1275) --- .../scripts/run-pre-trained-conformer-ctc.sh | 61 +++-- .../run-pretrained-conformer-ctc.yml | 9 +- .../jit_pretrained_decode_with_HL.py | 4 +- .../jit_pretrained_decode_with_HLG.py | 232 ++++++++++++++++++ egs/librispeech/ASR/local/prepare_lang_fst.py | 160 +++++++++--- egs/librispeech/ASR/prepare.sh | 2 +- 6 files changed, 412 insertions(+), 56 deletions(-) create mode 100755 egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 19cbd96fc..a82d85fb2 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -10,16 +10,30 @@ log() { cd egs/librispeech/ASR -repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 -git lfs install - +# repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 log "Downloading pre-trained model from $repo_url" -git clone $repo_url +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) +pushd $repo + +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/L_disambig.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lang_bpe_500/lexicon.txt" +git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "data/lang_bpe_500/words.txt" +git lfs pull --include "data/lm/G_3_gram.fst.txt" + +popd log "Display test files" tree $repo/ -ls -lh $repo/test_wavs/*.flac +ls -lh $repo/test_wavs/*.wav log "CTC decoding" @@ -28,9 +42,9 @@ log "CTC decoding" --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav log "HLG decoding" @@ -41,9 +55,9 @@ log "HLG decoding" --tokens $repo/data/lang_bpe_500/tokens.txt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - 
$repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav log "CTC decoding on CPU with kaldi decoders using OpenFst" @@ -65,7 +79,8 @@ ls -lh $repo/exp log "Generating H.fst, HL.fst" -./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 +./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt + ls -lh $repo/data/lang_bpe_500 log "Decoding with H on CPU with OpenFst" @@ -74,9 +89,9 @@ log "Decoding with H on CPU with OpenFst" --nn-model $repo/exp/cpu_jit.pt \ --H $repo/data/lang_bpe_500/H.fst \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav log "Decoding with HL on CPU with OpenFst" @@ -84,6 +99,16 @@ log "Decoding with HL on CPU with OpenFst" --nn-model $repo/exp/cpu_jit.pt \ --HL $repo/data/lang_bpe_500/HL.fst \ --words $repo/data/lang_bpe_500/words.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decoding with HLG on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index e268d840d..54845159d 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -23,13 +23,20 @@ on: pull_request: types: [labeled] + workflow_dispatch: + inputs: + test-run: + description: 'Test (y/n)?' + required: true + default: 'y' + concurrency: group: run_pre_trained_conformer_ctc-${{ github.ref }} cancel-in-progress: true jobs: run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc' + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index f0326ccdf..3420c4da3 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -2,12 +2,12 @@ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) """ -This file shows how to use a torchscript model for decoding with H +This file shows how to use a torchscript model for decoding with HL on CPU using OpenFST and decoders from kaldi. 
Usage: - ./conformer_ctc/jit_pretrained_decode_with_H.py \ + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --HL ./data/lang_bpe_500/HL.fst \ --words ./data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py new file mode 100755 index 000000000..42129f073 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HLG +on CPU using OpenFST and decoders from kaldi. + +Usage: + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_bpe_500/HLG.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldi_hmm_gmm +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + word2token: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HLG=HLG, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py index e8401123f..fb1e7f9c0 100755 --- a/egs/librispeech/ASR/local/prepare_lang_fst.py +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -8,6 +8,7 @@ tokens.txt, and words.txt and generates the following files: - H.fst - HL.fst + - HLG.fst Note that saved files are in OpenFst binary format. @@ -56,9 +57,114 @@ def get_args(): help="True if the lexicon has silence.", ) + parser.add_argument( + "--ngram-G", + type=str, + help="""If not empty, it is the filename of G used to build HLG. 
+ For instance, --ngram-G=./data/lm/G_3_fst.txt + """, + ) + return parser.parse_args() +def build_HL( + H: kaldifst.StdVectorFst, + L: kaldifst.StdVectorFst, + has_silence: bool, + lexicon: Lexicon, +) -> kaldifst.StdVectorFst: + if has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + + HL = kaldifst.compose(H, L) + kaldifst.determinize_star(HL) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HL): + for arc in kaldifst.ArcIterator(HL, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + + # Note: We are not composing L with G, so there is no need to add + # self-loops to L to handle #0 + + return HL + + +def build_HLG( + H: kaldifst.StdVectorFst, + L: kaldifst.StdVectorFst, + G: kaldifst.StdVectorFst, + has_silence: bool, + lexicon: Lexicon, +) -> kaldifst.StdVectorFst: + if has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # add-self-loops + token_disambig0 = lexicon.token2id["#0"] + 1 + word_disambig0 = lexicon.word2id["#0"] + + kaldifst.add_self_loops(L, isyms=[token_disambig0], osyms=[word_disambig0]) + + kaldifst.arcsort(L, sort_type="olabel") + kaldifst.arcsort(G, sort_type="ilabel") + LG = kaldifst.compose(L, G) + kaldifst.determinize_star(LG) + kaldifst.minimize_encoded(LG) + + kaldifst.arcsort(LG, sort_type="ilabel") + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + + kaldifst.arcsort(H, sort_type="olabel") + + HLG = kaldifst.compose(H, LG) + kaldifst.determinize_star(HLG) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HLG): + for arc in kaldifst.ArcIterator(HLG, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + return HLG + + +def copy_fst(fst): + # Please don't use fst.copy() + return kaldifst.StdVectorFst(fst) + + def main(): args = get_args() lang_dir = args.lang_dir @@ -82,43 +188,29 @@ def main(): else: L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) - if args.has_silence: - # We also need to change the input labels of L - add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) - else: - add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) - - # Invoke add_disambig_self_loops() so that it eats the disambig symbols - # from L after composition - add_disambig_self_loops( - H, - start=lexicon.token2id["#0"] + 1, - end=lexicon.max_disambig_id + 1, - ) - with open("H_1.fst.txt", "w") as f: - print(H, file=f) - - kaldifst.arcsort(H, 
sort_type="olabel") - kaldifst.arcsort(L, sort_type="ilabel") - logging.info("Building HL") - HL = kaldifst.compose(H, L) - kaldifst.determinize_star(HL) - - disambig0 = lexicon.token2id["#0"] + 1 - max_disambig = lexicon.max_disambig_id + 1 - for state in kaldifst.StateIterator(HL): - for arc in kaldifst.ArcIterator(HL, state): - # If treat_ilabel_zero_specially is False, we always change it - # Otherwise, we only change non-zero input labels - if disambig0 <= arc.ilabel <= max_disambig: - arc.ilabel = 0 - - # Note: We are not composing L with G, so there is no need to add - # self-loops to L to handle #0 - + HL = build_HL( + H=copy_fst(H), + L=copy_fst(L), + has_silence=args.has_silence, + lexicon=lexicon, + ) HL.write(f"{lang_dir}/HL.fst") + if not args.ngram_G: + logging.info("Skip building HLG") + return + + logging.info("Building HLG") + with open(args.ngram_G) as f: + G = kaldifst.compile( + s=f.read(), + acceptor=False, + ) + + HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon) + HLG.write(f"{lang_dir}/HLG.fst") + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index fca2c6cc4..93d010ea8 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -244,7 +244,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi if [ ! -f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py --lang-dir $lang_dir + ./local/prepare_lang_fst.py --lang-dir $lang_dir --ngram-G ./data/lm/G_3_gram.fst.txt fi done fi From a5ba1133c4cc6755530217876e8ff3bfb64e4d36 Mon Sep 17 00:00:00 2001 From: yaguang Date: Wed, 27 Sep 2023 17:33:38 +0800 Subject: [PATCH 049/113] Compatible with new lhotse versions. (#1278) --- egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index 180930747..6abe6c084 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -198,7 +198,7 @@ class AishellAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") From 8181d19860cbec21593e32af99deb4959f762540 Mon Sep 17 00:00:00 2001 From: yaguang Date: Wed, 27 Sep 2023 17:35:26 +0800 Subject: [PATCH 050/113] check bbpe model exists in advance. (#1277) --- egs/aishell/ASR/local/train_bbpe_model.py | 37 +++++++++++------------ 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py index d231d5d77..48160897d 100755 --- a/egs/aishell/ASR/local/train_bbpe_model.py +++ b/egs/aishell/ASR/local/train_bbpe_model.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- # You can install sentencepiece via: # # pip install sentencepiece @@ -26,12 +25,12 @@ # Please install a version >=0.1.96 import argparse -import re import shutil import tempfile from pathlib import Path import sentencepiece as spm + from icefall import byte_encode, tokenize_by_CJK_char @@ -74,6 +73,11 @@ def main(): model_type = "unigram" model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + model_file = Path(model_prefix + ".model") + if model_file.is_file(): + print(f"{model_file} exists - skipping") + return + character_coverage = 1.0 input_sentence_size = 100000000 @@ -88,23 +92,18 @@ def main(): _convert_to_bchar(args.transcript, train_text) - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) shutil.copyfile(model_file, f"{lang_dir}/bbpe.model") From 3abc290c1119d3adbb64112d854a9973ec486b3d Mon Sep 17 00:00:00 2001 From: Dongji Gao Date: Thu, 28 Sep 2023 19:52:46 -0400 Subject: [PATCH 051/113] Add scripts and recipe for BTC/OTC (#1255) --- egs/librispeech/WSASR/README.md | 224 ++++ .../WSASR/conformer_ctc2/__init__.py | 1 + .../WSASR/conformer_ctc2/asr_datamodule.py | 369 ++++++ .../WSASR/conformer_ctc2/attention.py | 1 + .../WSASR/conformer_ctc2/conformer.py | 949 ++++++++++++++ .../WSASR/conformer_ctc2/decode.py | 718 +++++++++++ .../WSASR/conformer_ctc2/export.py | 1 + .../WSASR/conformer_ctc2/label_smoothing.py | 1 + egs/librispeech/WSASR/conformer_ctc2/optim.py | 1 + .../WSASR/conformer_ctc2/scaling.py | 1 + .../WSASR/conformer_ctc2/subsampling.py | 184 +++ egs/librispeech/WSASR/conformer_ctc2/train.py | 1115 +++++++++++++++++ .../WSASR/conformer_ctc2/transformer.py | 1055 ++++++++++++++++ egs/librispeech/WSASR/figures/del.png | Bin 0 -> 14544 bytes egs/librispeech/WSASR/figures/ins.png | Bin 0 -> 16947 bytes .../WSASR/figures/otc_emission.drawio.png | Bin 0 -> 39476 bytes egs/librispeech/WSASR/figures/otc_g.png | Bin 0 -> 33339 bytes .../figures/otc_training_graph.drawio.png | Bin 0 -> 154014 bytes egs/librispeech/WSASR/figures/sub.png | Bin 0 -> 15900 bytes egs/librispeech/WSASR/local/compile_hlg.py | 173 +++ .../WSASR/local/compute_fbank_librispeech.py | 162 +++ .../WSASR/local/compute_ssl_librispeech.py | 100 ++ egs/librispeech/WSASR/local/filter_cuts.py | 160 +++ .../WSASR/local/get_words_from_lexicon.py | 48 + .../WSASR/local/make_error_cutset.py | 175 +++ egs/librispeech/WSASR/local/prepare_lang.py | 413 ++++++ .../WSASR/local/prepare_otc_lang_bpe.py | 295 +++++ .../WSASR/local/train_bpe_model.py | 100 ++ .../WSASR/local/validate_bpe_lexicon.py | 85 ++ .../WSASR/local/validate_manifest.py | 92 ++ egs/librispeech/WSASR/prepare.sh | 233 ++++ icefall/otc_graph_compiler.py | 246 ++++ icefall/utils.py | 64 + 33 files changed, 6966 insertions(+) create mode 100644 egs/librispeech/WSASR/README.md create mode 120000 egs/librispeech/WSASR/conformer_ctc2/__init__.py create mode 100644 
egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py create mode 120000 egs/librispeech/WSASR/conformer_ctc2/attention.py create mode 100644 egs/librispeech/WSASR/conformer_ctc2/conformer.py create mode 100755 egs/librispeech/WSASR/conformer_ctc2/decode.py create mode 120000 egs/librispeech/WSASR/conformer_ctc2/export.py create mode 120000 egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py create mode 120000 egs/librispeech/WSASR/conformer_ctc2/optim.py create mode 120000 egs/librispeech/WSASR/conformer_ctc2/scaling.py create mode 100644 egs/librispeech/WSASR/conformer_ctc2/subsampling.py create mode 100755 egs/librispeech/WSASR/conformer_ctc2/train.py create mode 100644 egs/librispeech/WSASR/conformer_ctc2/transformer.py create mode 100644 egs/librispeech/WSASR/figures/del.png create mode 100644 egs/librispeech/WSASR/figures/ins.png create mode 100644 egs/librispeech/WSASR/figures/otc_emission.drawio.png create mode 100644 egs/librispeech/WSASR/figures/otc_g.png create mode 100644 egs/librispeech/WSASR/figures/otc_training_graph.drawio.png create mode 100644 egs/librispeech/WSASR/figures/sub.png create mode 100755 egs/librispeech/WSASR/local/compile_hlg.py create mode 100755 egs/librispeech/WSASR/local/compute_fbank_librispeech.py create mode 100755 egs/librispeech/WSASR/local/compute_ssl_librispeech.py create mode 100644 egs/librispeech/WSASR/local/filter_cuts.py create mode 100755 egs/librispeech/WSASR/local/get_words_from_lexicon.py create mode 100755 egs/librispeech/WSASR/local/make_error_cutset.py create mode 100755 egs/librispeech/WSASR/local/prepare_lang.py create mode 100755 egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py create mode 100755 egs/librispeech/WSASR/local/train_bpe_model.py create mode 100755 egs/librispeech/WSASR/local/validate_bpe_lexicon.py create mode 100755 egs/librispeech/WSASR/local/validate_manifest.py create mode 100755 egs/librispeech/WSASR/prepare.sh create mode 100644 icefall/otc_graph_compiler.py diff --git a/egs/librispeech/WSASR/README.md b/egs/librispeech/WSASR/README.md new file mode 100644 index 000000000..3b8822fd2 --- /dev/null +++ b/egs/librispeech/WSASR/README.md @@ -0,0 +1,224 @@ +# Introduction + +This is a weakly supervised ASR recipe for the LibriSpeech (clean 100 hours) dataset. We train a +conformer model using [Bypass Temporal Classification](https://arxiv.org/pdf/2306.01031.pdf) (BTC)/[Omni-temporal Classification](https://arxiv.org/pdf/2309.15796.pdf) (OTC) with transcripts with synthetic errors. In this README, we will describe +the task and the BTC/OTC training process. + +Note that OTC is an extension of BTC and supports all BTC functions. Therefore, in the following, we only describe OTC. +## Task +We propose BTC/OTC to directly train an ASR system leveraging weak supervision, i.e., speech with non-verbatim transcripts. This is achieved by using a special token $\star$ to model uncertainties (i.e., substitution errors, insertion errors, and deletion errors) +within the WFST framework during training. + + +

+(Figures: `figures/sub.png`, `figures/ins.png`, `figures/del.png`.)
+
+*Examples of errors (substitution, insertion, and deletion) in the transcript. The grey box is the verbatim transcript and the red box is the inaccurate transcript. Inaccurate words are marked in bold.*
+
+We modify $G(\mathbf{y})$ by adding self-loop arcs into each state and bypass arcs into each arc.
+

+(Figure: `figures/otc_g.png`, the modified $G(\mathbf{y})$ with self-loop and bypass arcs.)
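+
+For intuition, the sketch below (an illustration only, not the actual implementation in `icefall/otc_graph_compiler.py`) lists the arcs of such a modified graph for a toy two-word transcript; the $\star$ label id and the arc weights are placeholders:
+
+```
+# Illustration only: the original word arcs, plus a star self-loop at every
+# state and a star bypass arc parallel to every word arc.  The weights stand
+# in for the per-epoch penalties described below.
+def otc_arcs(word_ids, star_id, self_loop_weight, bypass_weight):
+    arcs = []
+    for i, w in enumerate(word_ids):
+        arcs.append((i, i + 1, w, 0.0))                  # original word arc
+        arcs.append((i, i + 1, star_id, bypass_weight))  # bypass arc
+        arcs.append((i, i, star_id, self_loop_weight))   # self-loop arc
+    n = len(word_ids)
+    arcs.append((n, n, star_id, self_loop_weight))       # self-loop at the final state
+    return arcs
+
+# Toy transcript of two word ids; all numbers are placeholders.
+for src, dst, label, weight in otc_arcs([11, 27], star_id=2,
+                                         self_loop_weight=1.0,
+                                         bypass_weight=-1.0):
+    print(src, dst, label, weight)
+```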

+ +We incorporate the penalty strategy and apply different configurations for the self-loop arc and bypass arc. The penalties are set as + +$$\lambda_{1_{i}} = \beta_{1} * \tau_{1}^{i},\quad \lambda_{2_{i}} = \beta_{2} * \tau_{2}^{i}$$ + +for the $i$-th training epoch. $\beta$ is the initial penalty that encourages the model to rely more on the given transcript at the start of training. +It decays exponentially by a factor of $\tau \in (0, 1)$, gradually encouraging the model to align speech with $\star$ when getting confused. + +After composing the modified WFST $G_{\text{otc}}(\mathbf{y})$ with $L$ and $T$, the OTC training graph is shown in this figure: +
+(Figure: `figures/otc_training_graph.drawio.png`.)
+
+*OTC training graph. The self-loop arcs and bypass arcs are highlighted in green and blue, respectively.*
+
+ +The $\star$ is represented as the average probability of all non-blank tokens. +
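+
+Concretely, a minimal sketch of this computation (the exact implementation inside the training code may differ), using the illustrative numbers from the example below:
+
+```
+import torch
+
+# Illustrative per-frame log-probabilities of the non-blank tokens "a" and "b"
+# (2 frames x 2 tokens), taken from the example below.
+non_blank_log_probs = torch.tensor([[-1.2, -2.3],
+                                    [-1.9, -0.5]])
+
+# log of the *average* probability = logsumexp over tokens - log(num_tokens)
+star_log_prob = torch.logsumexp(non_blank_log_probs, dim=1) - torch.log(
+    torch.tensor(float(non_blank_log_probs.size(1)))
+)
+print(star_log_prob)  # roughly tensor([-1.6, -1.0])
+```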

+(Figure: `figures/otc_emission.drawio.png`, an example of per-frame emission probabilities and the resulting $\star$ weight.)

+ +The weight of $\star$ is the log average probability of "a" and "b": $\log \frac{e^{-1.2} + e^{-2.3}}{2} = -1.6$ and $\log \frac{e^{-1.9} + e^{-0.5}}{2} = -1.0$ for 2 frames. + +## Description of the recipe +### Preparation +``` +# feature_type can be ssl or fbank +feature_type=ssl +feature_dir="data/${feature_type}" +manifest_dir="${feature_dir}" +lang_dir="data/lang" +lm_dir="data/lm" +exp_dir="conformer_ctc2/exp" +otc_token="" + +./prepare.sh \ + --feature-type "${feature_type}" \ + --feature-dir "${feature_dir}" \ + --lang-dir "${lang_dir}" \ + --lm-dir "${lm_dir}" \ + --otc-token "${otc_token}" +``` +This script adds the 'otc_token' ('\') and its corresponding sentence-piece ('▁\') to 'words.txt' and 'tokens.txt,' respectively. Additionally, it computes SSL features using the 'wav2vec2-base' model. (You can use GPU to accelerate feature extraction). + +### Making synthetic errors to the transcript (train-clean-100) [optional] +``` +sub_er=0.17 +ins_er=0.17 +del_er=0.17 +synthetic_train_manifest="librispeech_cuts_train-clean-100_${sub_er}_${ins_er}_${del_er}.jsonl.gz" + +./local/make_error_cutset.py \ + --input-cutset "${manifest_dir}/librispeech_cuts_train-clean-100.jsonl.gz" \ + --words-file "${lang_dir}/words.txt" \ + --sub-error-rate "${sub_er}" \ + --ins-error-rate "${ins_er}" \ + --del-error-rate "${del_er}" \ + --output-cutset "${manifest_dir}/${synthetic_train_manifest}" +``` +This script generates synthetic substitution, insertion, and deletion errors in the transcript with ratios 'sub_er', 'ins_er', and 'del_er', respectively. The original transcript is saved as 'verbatim transcript' in the cutset, along with information on how the transcript is corrupted: + + - '[hello]' indicates the original word 'hello' is substituted by another word + - '[]' indicates an extra word is inserted into the transcript + - '-hello-' indicates the word 'hello' is deleted from the transcript + +So if the original transcript is "have a nice day" and the synthetic one is "a very good day", the 'verbatim transcript' would be: +``` +original: have a nice day +synthetic: a very good day +verbatim: -have- a [] [nice] day +``` + +### Training +The training uses synthetic data based on the train-clean-100 subset. +``` +otc_lang_dir=data/lang_bpe_200 + +allow_bypass_arc=true +allow_self_loop_arc=true +initial_bypass_weight=-19 +initial_self_loop_weight=3.75 +bypass_weight_decay=0.975 +self_loop_weight_decay=0.999 + +show_alignment=true + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./conformer_ctc2/train.py \ + --world-size 4 \ + --manifest-dir "${manifest_dir}" \ + --train-manifest "${synthetic_train_manifest}" \ + --exp-dir "${exp_dir}" \ + --lang-dir "${otc_lang_dir}" \ + --otc-token "${otc_token}" \ + --allow-bypass-arc "${allow_bypass_arc}" \ + --allow-self-loop-arc "${allow_self_loop_arc}" \ + --initial-bypass-weight "${initial_bypass_weight}" \ + --initial-self-loop-weight "${initial_self_loop_weight}" \ + --bypass-weight-decay "${bypass_weight_decay}" \ + --self-loop-weight-decay "${self_loop_weight_decay}" \ + --show-alignment "${show_alignment}" +``` +The bypass arc deals with substitution and insertion errors, while the self-loop arc deals with deletion errors. Using "--show-alignment" would print the best alignment during training, which is very helpful for tuning hyperparameters and debugging. 
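+
+For reference, the sketch below shows how the penalty schedule $\lambda_{i} = \beta \cdot \tau^{i}$ plays out with the values passed above; whether `train.py` indexes epochs exactly this way is an assumption, so treat it as an illustration for tuning rather than a description of the code:
+
+```
+# Illustration only: decay of the bypass / self-loop values over epochs,
+# following lambda_i = beta * tau**i with the values from the command above.
+initial_bypass_weight = -19
+initial_self_loop_weight = 3.75
+bypass_weight_decay = 0.975
+self_loop_weight_decay = 0.999
+
+for epoch in range(0, 25, 5):
+    bypass = initial_bypass_weight * bypass_weight_decay**epoch
+    self_loop = initial_self_loop_weight * self_loop_weight_decay**epoch
+    print(f"epoch {epoch:2d}: bypass {bypass:8.3f}, self-loop {self_loop:6.3f}")
+```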
+
+### Decoding
+```
+export CUDA_VISIBLE_DEVICES="0"
+./conformer_ctc2/decode.py \
+  --manifest-dir "${manifest_dir}" \
+  --exp-dir "${exp_dir}" \
+  --lang-dir "${otc_lang_dir}" \
+  --lm-dir "${lm_dir}" \
+  --otc-token "${otc_token}"
+```
+
+### Results (ctc-greedy-search)
+
+| Training Criterion | ssl (test-clean) | ssl (test-other) | fbank (test-clean) | fbank (test-other) |
+|--------------------|------------------|------------------|--------------------|--------------------|
+| CTC                | 100.0            | 100.0            | 99.89              | 99.98              |
+| OTC                | 11.89            | 25.46            | 20.14              | 44.24              |
+
+### Results (1best, blank_bias=-4)
+
+| Training Criterion | ssl (test-clean) | ssl (test-other) | fbank (test-clean) | fbank (test-other) |
+|--------------------|------------------|------------------|--------------------|--------------------|
+| CTC                | 98.40            | 98.68            | 99.79              | 99.86              |
+| OTC                | 6.59             | 15.98            | 11.78              | 32.38              |
+ +## Pre-trained Model +Pre-trained model: + +## Citations +``` +@inproceedings{gao2023bypass, + title={Bypass Temporal Classification: Weakly Supervised Automatic Speech Recognition with Imperfect Transcripts}, + author={Gao, Dongji and Wiesner, Matthew and Xu, Hainan and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev}, + booktitle={INTERSPEECH}, + year={2023} +} + +@inproceedings{gao2023learning, + title={Learning from Flawed Data: Weakly Supervised Automatic Speech Recognition}, + author={Gao, Dongji and Xu, Hainan and Raj, Desh and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev}, + booktitle={IEEE ASRU}, + year={2023} +} +``` diff --git a/egs/librispeech/WSASR/conformer_ctc2/__init__.py b/egs/librispeech/WSASR/conformer_ctc2/__init__.py new file mode 120000 index 000000000..43a85af20 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/__init__.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py new file mode 100644 index 000000000..1b6991bcd --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py @@ -0,0 +1,369 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# 2023 John Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=False, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/ssl"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--train-manifest", + type=str, + default="librispeech_cuts_train-clean-100.jsonl.gz", + help="Train manifest file.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." 
+ ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + 
logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/WSASR/conformer_ctc2/attention.py b/egs/librispeech/WSASR/conformer_ctc2/attention.py new file mode 120000 index 000000000..e808a6f20 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/attention.py @@ -0,0 +1 @@ +../../ASR/conformer_ctc2/attention.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/conformer.py b/egs/librispeech/WSASR/conformer_ctc2/conformer.py new file mode 100644 index 000000000..db4821d37 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/conformer.py @@ -0,0 +1,949 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# 2022 Xiaomi Corp. (author: Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import Optional, Tuple + +import torch +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledLinear, +) +from subsampling import Conv2dSubsampling, Conv2dSubsampling2 +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension, also the output dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + layer_dropout (float): layer-dropout rate. + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. 
+ """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 2, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.2, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + layer_dropout=layer_dropout, + ) + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4 and subsampling_factor != 2: + raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if self.subsampling_factor == 4: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + elif self.subsampling_factor == 2: + self.encoder_embed = Conv2dSubsampling2(num_features, d_model) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), self.subsampling_factor, supervisions) + if mask is not None: + mask = mask.to(x.device) + + # Caution: We assume the subsampling factor is 4! + + x = self.encoder( + x, pos_emb, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) + + # x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + # return x, lengths + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). 
+ dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
+ + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. 
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). 
+
+    """
+
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels. This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively. (For most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1 is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.) The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+        x = self.depthwise_conv(x)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
+
+
+if __name__ == "__main__":
+    feature_dim = 50
+    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
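# The convolution module above follows the standard Conformer pattern:
# pointwise conv to 2*channels -> GLU gate -> depthwise conv
# (groups=channels) -> activation -> pointwise conv back to channels,
# all operating on (batch, channels, time).  Below is a minimal sketch of
# the same data flow with plain torch layers, assuming toy sizes and using
# nn.Conv1d / SiLU as simplified stand-ins for ScaledConv1d / DoubleSwish
# (the ActivationBalancer layers, which mainly shape gradients during
# training, are left out):
import torch
import torch.nn as nn

channels, kernel_size, T, N = 16, 31, 100, 2
pointwise1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
depthwise = nn.Conv1d(
    channels,
    channels,
    kernel_size,
    padding=(kernel_size - 1) // 2,
    groups=channels,
)
pointwise2 = nn.Conv1d(channels, channels, kernel_size=1)

x = torch.randn(T, N, channels)  # (time, batch, channels)
y = x.permute(1, 2, 0)  # (batch, channels, time)
y = nn.functional.glu(pointwise1(y), dim=1)  # gate halves 2*channels back to channels
y = torch.nn.functional.silu(depthwise(y))  # SiLU standing in for DoubleSwish
y = pointwise2(y).permute(2, 0, 1)  # back to (time, batch, channels)
assert y.shape == x.shape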
+ f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py new file mode 100755 index 000000000..3fa045533 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Quandong Wang) +# 2023 Johns Hopkins University (Author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.otc_graph_compiler import OtcTrainingGraphCompiler +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--otc-token", + type=str, + default="", + help="OTC token", + ) + + parser.add_argument( + "--blank-bias", + type=float, + default=0, + help="bias (log-prob) added to blank token during decoding", + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--method", + type=str, + default="ctc-greedy-search", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-greedy-search. It only use CTC output and a sentence piece + model for decoding. It produces the same results with ctc-decoding. + - (2) 1best. 
Extract the best path from the decoding lattice as the + decoding result. + """, + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_200", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 2, + "feature_dim": 768, + "nhead": 8, + "dim_feedforward": 2048, + "encoder_dim": 512, + "num_encoder_layers": 12, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def ctc_greedy_search( + nnet_output: torch.Tensor, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, +) -> List[List[int]]: + """Apply CTC greedy search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + Returns: + List[List[int]]: best path result + """ + batch_size = memory.shape[1] + # Let's assume B = batch_size + encoder_out = memory + encoder_mask = memory_key_padding_mask + maxlen = encoder_out.size(0) + + ctc_probs = nnet_output # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_(encoder_mask, 0) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps, scores + + +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. 
`len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + nnet_output[:, :, 0] += params.blank_bias + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="trunc", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="trunc", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor + 2, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "ctc-greedy-search": + hyps, _ = ctc_greedy_search( + nnet_output, + memory, + memory_key_padding_mask, + ) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + + if params.method in ["1best"]: + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + + return {key: hyps} + else: + assert False, f"Unsupported decoding method: {params.method}" + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. 
+ enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + assert "▁" not in args.otc_token + args.otc_token = f"▁{args.otc_token}" + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + # remove otc_token from decoding units + max_token_id = max(lexicon.tokens) - 1 + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = OtcTrainingGraphCompiler( + params.lang_dir, + params.otc_token, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding" or params.method == "ctc-greedy-search": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.encoder_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/conformer_ctc2/export.py b/egs/librispeech/WSASR/conformer_ctc2/export.py new file mode 120000 index 000000000..5f484e391 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/export.py @@ -0,0 +1 @@ +../../ASR/conformer_ctc2/export.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py b/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py new file mode 120000 index 000000000..c050ea637 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py @@ -0,0 +1 @@ +../../ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/optim.py b/egs/librispeech/WSASR/conformer_ctc2/optim.py new file mode 120000 index 000000000..db836b5e0 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/optim.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/scaling.py b/egs/librispeech/WSASR/conformer_ctc2/scaling.py new file mode 120000 index 000000000..bd0abfeee --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/scaling.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/subsampling.py b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py new file mode 100644 index 000000000..2ba802866 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# 2022 Xiaomi Corporation (author: Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, +) + + +class Conv2dSubsampling(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + + self.conv = torch.nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class Conv2dSubsampling2(torch.nn.Module): + """Convolutional 2D subsampling (to 1/2 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim) where + T' = (T - 1) // 2 - 2, which approximates T' == T // 2 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + assert in_channels >= 7 + super().__init__() + + self.conv = torch.nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * ((in_channels - 1) // 2 - 2), out_channels + ) + self.out_norm = BasicNorm(out_channels, learn_eps=False) + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.unsqueeze(1) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) + return x diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py new file mode 100755 index 000000000..fe6c5af91 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
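# The two subsampling modules defined above (Conv2dSubsampling to roughly
# 1/4 length, Conv2dSubsampling2 to roughly 1/2 length) shrink the time
# axis exactly as their docstrings state.  A quick numeric check with plain
# nn.Conv2d stacks that copy the kernel sizes and strides; the channel
# counts and T=100, idim=80 are toy values chosen only for illustration:
import torch
import torch.nn as nn

T, idim = 100, 80

quarter = nn.Sequential(  # mirrors Conv2dSubsampling
    nn.Conv2d(1, 8, 3, padding=1),
    nn.Conv2d(8, 32, 3, stride=2),
    nn.Conv2d(32, 128, 3, stride=2),
)
half = nn.Sequential(  # mirrors Conv2dSubsampling2
    nn.Conv2d(1, 8, 3, padding=1),
    nn.Conv2d(8, 32, 3, stride=2),
    nn.Conv2d(32, 128, 3, stride=1),
)

x = torch.randn(1, 1, T, idim)
assert quarter(x).shape[2] == ((T - 1) // 2 - 1) // 2  # 24 frames for T=100
assert half(x).shape[2] == (T - 1) // 2 - 2  # 47 frames for T=100
# The frequency axis shrinks the same way, which is where the input size of
# the final linear layer comes from: layer3_channels * (((idim - 1) // 2 - 1) // 2)
# for Conv2dSubsampling, and layer3_channels * ((idim - 1) // 2 - 2) for
# Conv2dSubsampling2.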
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc2/train.py \ + --world-size 4 \ + --manifest-dir data/ssl \ + --train-manifest librispeech_cuts_train-clean-100_0.17_0.17_0.17.jsonl.gz \ + --exp-dir conformer_ctc2/exp \ + --lang-dir data/lang_bpe_200 \ + --otc-token "" \ + --allow-bypass-arc true \ + --allow-self-loop-arc true \ + --initial-bypass-weight -19 \ + --initial-self-loop-weight 3.75 \ + --bypass-weight-decay 0.975 \ + --self-loop-weight-decay 0.999 \ + --show-alignment true +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.decode import one_best_decoding +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.otc_graph_compiler import OtcTrainingGraphCompiler +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions_otc, + get_texts, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_200", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. 
This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.0, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--otc-token", + type=str, + default="_", + help="OTC token", + ) + + parser.add_argument( + "--allow-bypass-arc", + type=str2bool, + default=True, + help="""Whether to add bypass arc to training graph for substitution + and insertion errors (wrong or extra words in the transcript).""", + ) + + parser.add_argument( + "--allow-self-loop-arc", + type=str2bool, + default=True, + help="""Whether to self-loop bypass arc to training graph for deletion errors + (missing words in the transcript).""", + ) + + parser.add_argument( + "--initial-bypass-weight", + type=float, + default=0.0, + help="Initial weight associated with bypass arc", + ) + + parser.add_argument( + "--initial-self-loop-weight", + type=float, + default=0.0, + help="Initial weight associated with self-loop arc", + ) + + parser.add_argument( + "--bypass-weight-decay", + type=float, + default=1.0, + help="""Weight decay factor of bypass arc weight: + bypass_arc_weight = intial_bypass_weight * bypass_weight_decay ^ ith-epoch""", + ) + + parser.add_argument( + "--self-loop-weight-decay", + type=float, + default=1.0, + help="""Weight decay factor of self-loop arc weight: + self_loop_arc_weight = intial_self_loop_weight * self_loop_weight_decay ^ ith-epoch""", + ) + + parser.add_argument( + "--show-alignment", + type=str2bool, + default=True, + help="Whether to print OTC alignment during training", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 1, + "reset_interval": 200, + "valid_interval": 800, # For the 100h subset, use 800 + "alignment_interval": 25, + # parameters for conformer + "feature_dim": 768, + "subsampling_factor": 2, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for ctc loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + graph_compiler: OtcTrainingGraphCompiler, + is_training: bool, + warmup: float = 2.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute OTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model( + feature, supervisions, warmup=warmup + ) + # Set the probability of OTC token as the average of non-blank tokens + # under the assumption that blank is the first and + # OTC token is the last token in tokens.txt + _, _, V = nnet_output.shape + + otc_token_log_prob = torch.logsumexp( + nnet_output[:, :, 1:], dim=-1, keepdim=True + ) - torch.log(torch.tensor([V - 1])).to(device) + + nnet_output = torch.cat([nnet_output, otc_token_log_prob], dim=-1) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts, utt_ids, verbatim_texts = encode_supervisions_otc( + supervisions, subsampling_factor=params.subsampling_factor + ) + + bypass_weight = graph_compiler.initial_bypass_weight * ( + graph_compiler.bypass_weight_decay ** (params.cur_epoch - 1) + ) + self_loop_weight = graph_compiler.initial_self_loop_weight * ( + graph_compiler.self_loop_weight_decay ** (params.cur_epoch - 1) + ) + + decoding_graph = graph_compiler.compile( + texts=texts, + allow_bypass_arc=params.allow_bypass_arc, + allow_self_loop_arc=params.allow_self_loop_arc, + bypass_weight=bypass_weight, + self_loop_weight=self_loop_weight, + ) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=3, + ) + + otc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + 
output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + assert params.att_rate == 0.0 + loss = otc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["otc_loss"] = otc_loss.detach().cpu().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + if params.show_alignment: + if params.batch_idx_train % params.alignment_interval == 0: + for index, utt_id in enumerate(utt_ids): + verbatim_text = verbatim_texts[index] + utt_id = utt_ids[index] + + lattice = k2.intersect_dense( + decoding_graph, + dense_fsa_vec, + params.beam_size, + ) + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + hyp_ids = get_texts(best_path)[index] + hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids] + hyp_text = "".join(hyp_text_list).replace("▁", " ") + + logging.info(f"[utterance id]: {utt_id}") + logging.info(f"[verbatim text]: {verbatim_text}") + logging.info(f"[best alignment]: {hyp_text}") + logging.info(bypass_weight) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: OtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + graph_compiler: OtcTrainingGraphCompiler, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + # scaler.scale(loss).backward() + + try: + # loss.backward() + scaler.scale(loss).backward() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error(f"failing batch size:{batch_size} ") + raise + + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + if loss_info["otc_loss"] == float("inf"): + logging.error("Your loss contains inf, something goes wrong") + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = OtcTrainingGraphCompiler( + params.lang_dir, + otc_token=params.otc_token, + device=device, + initial_bypass_weight=params.initial_bypass_weight, + initial_self_loop_weight=params.initial_self_loop_weight, + bypass_weight_decay=params.bypass_weight_decay, + self_loop_weight_decay=params.self_loop_weight_decay, + ) + + # remove OTC token as it is the average of all non-blank tokens + max_token_id = graph_compiler.get_max_token_id() - 1 + # add blank + num_classes = max_token_id + 1 + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.encoder_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + print(model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + graph_compiler=graph_compiler, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: OtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + warmup=0.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + assert "▁" not in args.otc_token + args.otc_token = f"▁{args.otc_token}" + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/conformer_ctc2/transformer.py b/egs/librispeech/WSASR/conformer_ctc2/transformer.py new file mode 100644 index 000000000..41e6cd357 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/transformer.py @@ -0,0 +1,1055 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright 2022 Xiaomi Corp. (author: Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from attention import MultiheadAttention +from label_smoothing import LabelSmoothingLoss +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledEmbedding, + ScaledLinear, +) +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + layer_dropout (float): layer-dropout rate. + """ + super().__init__() + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4 and subsampling_factor != 2: + raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). 
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + layer_dropout=layer_dropout, + ) + + self.encoder = TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = ScaledEmbedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + + self.decoder = TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + ) + + self.decoder_output_layer = ScaledLinear( + d_model, self.decoder_num_class, bias=True + ) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, + x: torch.Tensor, + supervision: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision, warmup + ) + + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. 
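+
+        A rough, illustrative sketch (``model`` and ``supervisions`` are
+        assumed to exist already; the exact output length depends on the
+        Conv2dSubsampling module)::
+
+            >>> x = torch.randn(8, 500, 80)  # (N, T, C) fbank features
+            >>> memory, mask = model.run_encoder(x, supervisions)
+            >>> # memory has shape (T', N, d_model), T' ~= T // subsampling_factor
+            >>> # mask has shape (N, T'), or is None when supervisions is None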
+ """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is (T, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, T, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. 
+ if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.self_attn = MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
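+        # (ActivationBalancer with these settings encourages roughly 45%-55%
+        #  of the output channels to be positive and keeps their magnitude
+        #  below about 6.0; see scaling.py for details.)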
+ self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # src_att = self.self_attn(src, src, src, src_mask) + src_att = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.self_attn = MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
+ self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + tgt_orig = tgt + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # tgt_att = self.self_attn(tgt, tgt, tgt, tgt_mask) + tgt_att = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = tgt + self.dropout(tgt_att) + + # src_att = self.src_attn(tgt, memory, memory, memory_mask) + src_att = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout(src_att) + + tgt = tgt + self.dropout(self.feed_forward(tgt)) + + tgt = self.norm_final(self.balancer(tgt)) + + if alpha != 1.0: + tgt = alpha * tgt + (1 - alpha) * tgt_orig + + return tgt + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + mask: (S, S). + src_key_padding_mask: (N, S). 
+            S is the source sequence length, T is the target sequence length,
+            N is the batch size, E is the feature number
+
+        """
+        output = src
+
+        for mod in self.layers:
+            output = mod(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    r"""TransformerDecoder is a stack of N decoder layers
+
+    Args:
+        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+        num_layers: the number of sub-decoder-layers in the decoder (required).
+
+    Examples::
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(10, 32, 512)
+        >>> out = transformer_decoder(tgt, memory)
+    """
+
+    def __init__(self, decoder_layer: nn.Module, num_layers: int) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(decoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        r"""Pass the input through the decoder layers in turn.
+
+        Args:
+            tgt: the sequence to the decoder (required).
+            memory: the sequence from the last layer of the encoder (required).
+            tgt_mask: the mask for the tgt sequence (optional).
+            memory_mask: the mask for the memory sequence (optional).
+            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+            memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+        Shape:
+            tgt: (T, N, E).
+            memory: (S, N, E).
+            tgt_mask: (T, T).
+            tgt_key_padding_mask: (N, T).
+            memory_key_padding_mask: (N, S).
+            S is the source sequence length, T is the target sequence length,
+            N is the batch size, E is the feature number
+
+        """
+        output = tgt
+
+        for mod in self.layers:
+            output = mod(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                warmup=warmup,
+            )
+
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note::
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model:
+            Embedding dimension.
+          dropout:
+            Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (1, T, d_model). Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape (N, T, C).
+        Returns:
+          Return None.
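+
+        A small illustrative example::
+
+            >>> pos_enc = PositionalEncoding(d_model=256)
+            >>> x = torch.randn(4, 100, 256)
+            >>> pos_enc.extend_pe(x)
+            >>> pos_enc.pe.shape
+            torch.Size([1, 100, 256])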
+ """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, + subsampling_factor: Optional[int] = 4, + supervisions: Optional[Supervisions] = None, +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4 or 2. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. 
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + if subsampling_factor == 4: + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + elif subsampling_factor == 2: + lengths = [(i - 1) // 2 - 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. 
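+
+    A tiny illustrative example (hypothetical token IDs)::
+
+        >>> add_eos([[2, 3], [5]], eos_id=1)
+        [[2, 3, 1], [5, 1]]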
+    """
+    return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
diff --git a/egs/librispeech/WSASR/figures/del.png b/egs/librispeech/WSASR/figures/del.png
new file mode 100644
index 0000000000000000000000000000000000000000..38973980bec204a74d0c97730921c7e29c4ff9c2
GIT binary patch
literal 14544
[base85-encoded PNG data for figures/del.png omitted]
diff --git a/egs/librispeech/WSASR/figures/ins.png b/egs/librispeech/WSASR/figures/ins.png
new file mode 100644
index 0000000000000000000000000000000000000000..2d0e807a971cc89537582178befe3a0f0d1d4ac6
GIT binary patch
literal 16947
[base85-encoded PNG data for figures/ins.png omitted]
zOHu4RWGBoi^`FR-$<4=QWwBk@08~^4X5(F(9HE0s%217=VR>rfC32 zA-^4#opzp6lH8c@85@YzFBTc9GVRs#SJ*ax14}9;Pl8w!k|*sMeYz(vmo)@Oq!Uno zb{AJ<64;h~Z`AxFG=Q9cQr~t9wC<3ihHr+g@LGWJO48m747|GBuL6XMN829D_wYFM zd;>r^T~j%G<`(Y#((7UIQvD~#AsyT;TDFW2-l;RYfK)$9(3l3ak`5J6xFZbu17Y*` z^}@WEBlA+5q1n3|(*}o#RE*kv;=4<-S-abVeZ*jso|rUPqPr#W^}|~??!>)B8I{dE zuML2CtXPd=2m8H%E?+eZhqmP)_M(Hz$qNeuu^GID#J=hT^nToZ0j1YKPhT2{{^I%J zrI#Y)JNx|St{(FH-&b3Fn;5*o$S48n_m`BE3=2TFieNb1?oA0CDR^MZxRPq??;<0n zo923!{*1$C&xcC&WUhYd)0%Qj8w-l@%4*p)5p>N;c|$c8gP~M;ugPcTnLeX-_gHYs zB1j=|YLFg^pq~6te7qhkY=6_fN+?IUlq>H2yu=vv_|%ANfs2kKg{q#Wfq4nM91J%SpEOOB=6zoB zsC@UpgPp`6lW{ky5M;N$^!JM2j;)(p2~wlgBR$15vZL(ZtGJG3QusS-krn9PyUI|Q z1|7d%22P{yCc5sM`d%+8w@T-WlzR(AhwOed=Cg;S$5iy7QJy!+tPBj{b8~a3Pd{S@ z^4SRA)>c}Bx_Yn=A9+N8)e@tr`S9ARc{(aII%wEb)QaEV-3Let=#q0r zD^ntV^q%ztn~e&lux9~oG~$Ui7~gL(4+z#VXYwz1w|_Xce4oIZ-x#eo%u19;bnm>XMp=xOV)oiMq&)Eb(vX}7OSPS-&n32iUts^#}uIMq{c%^jtK zRfnas>#e4(oZ5uwU!Z1*Uu6uY*hD@W!x`M43UH8btqN^9Szt*KG$wrHF?Cv_tB83b zZ-ak?&sZvV3oxjS6KUnIfNTdtu}B-=je6;tVscc}I!}R=^$EFW=!pSE?vCI-$fMVT zxeXt(B21qYY6}WHM{~1$leEK~w-&#Alu&(5l@b)fA|tod5sL4+B1Ay1;`bI=lfr%j z@$t=wx{YYxZ`TPV(Fu0n-w3Tf723~;=oh?^co&F`xGht-+P9bt3hBPsb`x|xsXTOa zNL=N%I-$8K(#0XCs)+Cc-sYhza&XVnbO;pd6x?})>*sov!3F2*RU~PnuQ{Cxc&4Um zx>k(0guK$xi3YnAj-B4ou8uGVH5`!1Iuo5_3vB9uf&VmK6I*r-wnfhsFLLZ4Sf;nn zl|Rv%ywjxo!CEpC+8kCqrou=6>ZWKmd3K^tk;g5#x98~YirrX9W`lG=PJB@ymT!22 zXJici4Kd37q|dc>10?EQ1LfZ!p0#8mOGuuku)B=wYsAToHFu~A?e?7TNe>rJhhRI} zH9JP59n8vBD#PwB&zi+1D$S4-)A{u#URp-JOkgt*8Lu@DJv=+kgyu3(x$l0x2Qzf? z)#X1U!Nu*IyNUYl2}hvemgsA@>~<*zp2ZB*h8xH2+@9aQMe}UMeEh1P%AJzMrc3Kz zN}i(4Sd1;+3)_O_4OZ`0d-L>2&_G+P+BJ=5gI8JOx#ZNHFz8blPwJ>KO;Jcf$(G z%d8^V>d2w+RKbVeSjIO#)hBjgeanDv%qPDAXQg#?M0HI zjl@E{r+%$>_|>>S>RrihIYj4Ee=anvTEu`|c&$(T$B%ZCckw%h8PQDO4$yqZ+wibP zJc(~4=3{mS^&Dy9Bqjx&Z7oB!>eSh6)Lku-8$O&`*RMOTyj6^N)R?UmpmF`GbBU?F^P)S`Om9DCzT(>HkQZ-RoaiX_3k8*7%rr%C7STX7G*Eir zt3A;g(jX@ZMU$Fy`mrFtaQ9^7OPs2zOrGs(7|(0=G3>9{MjgG z$;wu8;(ets(BPsx(F>-$7vMB8jMlRNGoRC|m^Y5}-Z#%?qSxL!v}ZeYeo=`tet*t> zBD=ZU*h1wC-QY4H-tNZ*o294e;zbse-GZ!$2|D(z;+rVkAl%kt(^FE|Zs1svOs3fG zo^tQEN7;kNlR(&j#p*bcO;MxW&`|`W;^nb1)!hE=3fjvpnYXU=I#+9qkU$NUaRmhh zXjx-r`bwEie2OoUfYp8(M(M!46A$&?TWRF>T`!QvM2bsQ{jgeNF<%~^^5lsZ1GdH9 z{`K8K=f@a^10WoaYwg|Lp(-6@q&83< zUHf(H$YsTw<`LAcDx(H_V{OUxZMfS(c8_G}A<|TRrjqrkuvG3b_xo3py1rx;&_*Wf ziD5;a*i)l-F?Q7lt+v%Bdk<%^3z-1 zu=BR1k#qv~9FQRidW2QOveo`dD^(I1mr}@U{|onir{!+folN?Ti7u^D{Vk?GcaJnP zIf}K=M?`+N8r;m_F*6>A?13Q%*2>+{q&S2(iEmU>3`84@l4h~pt{H82;1&oGa9$b! 
zGFsGAy$u^?Z#emFr|dWv*mkr+hCf0Tdp`FLL;79LrXECux5z$MC9+tVBVXP;6&@+N zPY#1_S%pTeh)LgO(f10=Am3dGlqhYjLFqAmo0hJ_hH1NkV=g7Z-Ig?4yVbv+tC9Qg zL9+q*L~W$B(4?3oy8V<_7p3UUDf~#zyiz`&%~!`M#Ms#FoFM4wk$HpXUh&~vkgP!_ z;L;YHVHxBlmW-;z=8V7CwMqw>t%VXbpPe43xVLJ*6g*@!|A75Y@9qdv%Kn^etR>=m zSS4Z-ionEpUeUBa zZp{tkjXkYrAt~I|iBrd>Q`^G3a|1~U{T&?;iRnFi;zz#Ht%4EMfpKKu_cWVg(+qK0W+FJEh=m_FaF>fllL_Pb#>9%k-1Na*qoW^d48buV6+Ua&)kUh~nF4`?6Uk}&VLMm&om7w@}YQ($m5UW|4tiC5r+ zCDC^Hb(=d*6)rzeT7rr-D2$cCYnyWTR?@^owoC>6vuo5zuLWCj@{1T?eV=+&oiWo@ z!H>ztWn>8gn?anAA(Gv8CVN4ls%Zx&T7`EL$$Jknxoc5Lg(!@ZL|1OX_j9OIszIyb zNl8ia^C5dUCJx)y8{y#nBq0|6;g-on9!+Wye~IJj>|B(NmDFIS!uu5np8RsL9bHtd zUoi_^z-dKeMJj%roSYx&22|q>7k$L8baZq`8KUYlJBTC&c*h(AB{tKA0%h8*pWogo z@~s~|-#;*CyA85*MX<}xJ1Vx%UgZ6*vhr90JFi9jmsB}-(YU>3tUKv&91$X+PpJ1} z!`b?AeAgoI>mTPWmbCAu`Z^@uN6l>W#TwiZyr<#hjfi_oC@UW7^62eB^5o=~7gKpi z(>_Wa(qLsS@UhZ{AcKX;er>Bts~ENEfn{$<7}Q3)zOwMv*R?h4-usTzyrQD={v1I0 ziSTyl{m$-gb;FtXyecnc{rsU*1YY9c#RtHasDf&X)aup^ecHV$@{2F=4cb71W&txIwbgWyW~Z_s#l4+bO7Aoed2P*ZRIY^pZ?ub zuG^9Flo-jj1KbF&xt%gZcOq|iAZB{43&b;gtWKbF5Un+%x?=hDbrJV`V81ijy3)4o z&B4~|>35vtZYz2vKWnl(6CFAvakJdmvY7WktkK{A6)I+OCY_${@XfaRr}<^*>sOzQ zkt1bg_8%w0u->sU6d;L71(7KtE`NM(jDHhGgO*F!)n!~sil8m}Ho5uNDOw3p(S%oL&x#?&n)&bj}o)bz8G z23-7rxhFr!ZR1H&hIjW$4;Qlu#g$c%0TOKSlO*q_qs* z8}`Zd)yGXU-QrtO-LIc-+^JB+56^Z*IxX6O?(*D2f6r=dVb5{z4slJ1@FmO+d-owO zi+CqI7exyCvOt_f4XWHt3*`>(-M{+9@@<=nZk+?=qnkYP3Y%Nn%#xy2PJSn|(RL2?ZY(`rP^>mN`i!&Xo?Ff$d3S~{HkH0L~R3BS5xy$KlvmMkb z`Jh<}h=LU=H;%t-t`DEW#Uv$hjh!p+LO)gGD6HymRxh8tA?8|=m_O4-$CCsACtT-m zcB659Uo2#6VOR+7p4vB1eI@qUHoEV5knVX6Anl||Bm8x`mAdEyd!_@wGsL+n*yu4`Db=6F` z^Y%nO?fv(|G3CmK3i7(hER*gDx35_&nOQu6PGp{r%TRmrHYI=-C&*|)6hMy6 zKjgq7CNOv%cAdsy{q4L)i4_G>BGc@?-y0-~+o!zg^~B1V%HQN%S4RMWVy{njxx7RK-Fd;*-|RdA7#K(y_`2 zY*u>rNoWzve59?*!T63(hqyS;!(IZDVS7Xb%52!z`d)(UM2Ho(YB>uNm8h-lREHM{4Kf7T)8pXRF3B;QDcg{U;=k_Ha%A>6N5{jS2 zs>omSWRc#?p#DgdxpR8-3mZ_=WH}v9d~<^5PG8!c-iW^LA}I@;d#LVlz1nVYG(30O zEix8e-;j`lckc(gjmZ4-Nw9|N^lL0+Tg!un^n=5$d@fpUIiEK*rj$6Cul*bDBa z0w{yTn5V95X2_nz1EHFaBtU`Ni)|ntw5wM!t}&Gv&BP+Zar~BPY-2D7vwpfCTp=Lq z^ZCQGxR&6u=HOv{UoL^*QMg-5*RE~O!LQ_KodbId zq24si#`SFpH)qzQgi}yd&rab|pXzX_r3dO9=S}qb6w{FWwG4dj<(-Lr0G-E8yUq(E z%oR=DiS9VlOb=q+%`r@CObN6p?V9_&iSd6+q;et7_3F5vIx2fyCvl zq&sJ4rJHH58iay=qP(92Kf~o>^y6Z)p)6)T1>Q^ht<@W^-zfOF`nII zt{uyZHA{sm!w*nT_*1&`5OWTKd?hG>Rjj52REMDyk-4+m32K0Hg=cJ}KA`xX+n?)< z90pEtEdy0W5y$fw)sdn=^{7o+sEK{{1ywZjcbFZUH|B%rb7058#V}WI8BZ!i8b0q31K z>{$KU9^blS@iu|sdrk5>EfBsD1|nXpm7As_h-;yDvs_Zdz^q;R1zIlq*o)1Z1Rj*^ z=A9hHg!ozA1~VjK3xm0yd^tu^=nX=CZtDfZ_G?(fVHucv-UyGHGZCa&@rZGKB2Q1* zB0?Ve#yuEcREhVt$Y8&irPX;_@d~Q>8m)JvI{fDPm%El${-lvnxp@nX z0xf=Ojlj{KA>derobSn2=q9FQ!xN{)xAibV=c6yypI@@n*nEo1$<4I{YP3E3dttF& zFdeX-KJ6}9%$~$2@6;h|(`E;2fphjytJO(Yf`#vum8RLHRrM8NjL~!6jzxAIt&of# zJ5TG*+f5L~wo`lY+r-kj}dZ!cjgC5Jip!e*PSC*xv)h_M8V3^;@*t4%G9RQHDG z+xw}WV7jRmW#6D%x!IGejV6ka5maf{7|7(Bx@9Zf8gs<~*8q<^;ugljWVn)SHY>|+ zM`WUM_>RGI0iifX6nsJt1k6RCT;#lrS+?PF5C0X9J`r?sMhUQEbzhrF`cvl9Fr@>K%A8qGqg-&zC8N(N=feI$XxP_D6>uhZ0AXA5uQgH5e~ zGh$U0T%$NhMu>=rprb~Yl{hIM7X2n|oSqV`iY2{`jSV(tGobu~=kpg}G?3#|wPn2i z=kshKP+%rSD{msBe@A}2)?S&d^`t5comHa{=ltlAbZGl|@ky+grQrkW`)xo;_G`9m zlB>wqvViizl&0q9K2~eJR4u<7y9Gc+tl9DNd%)_CdQyh5kAxi+@=<5tdNn!V4@Jek zCyfX0+>`Q3`6Az}n7*mK_%^m41!U2rb{PW=E1gCSGH7skg3Z~4=c<;}ow*pFMo$w@ z#V{U{uoa+KCq+jqD}9>*YA?UTAVBCj#+tuQm&hST15xX+Rp9vJGO!sFg^`OWGwgwC z?YUNp)=H-~HZ?pUzPO&0IN(HrN+co6GO$OfYLp6;(&>45H9=#d4Q?B(yvq5SQk~dF zbBNk5UVE!jz)=Otn4y9U&j@|@S)|-D9dK1NkaW&q#$HG$)htUn6a%Qwz|XIphBMLV z7LANeb{EJY5ri-9;7R8D0(C7pXfCf+41Huq8mNPS8uisN`{`UDOKb_-l-8rx7O^R= 
zdA9~{vOV*7>%GgbMP>TI&aV4$`@VEsfjM+_g!{OeI(*`LftC3z?wG~#YzcRkGrfmL z!{eio2~txbZ^@F{^sL8^-l|6mdfbUiVttVEK0`Ql$5p@H$rSdjcDfKaW3gmMy>?&~ z*M==}=!f&E)_z*VYV@v$D>3QkcEm6aN>$zkpeQHDhHRCN+TTWB?!{QCj)XE8Uj32c zxqy7?r)l3a7WqxF%Gh?ge%t`_c5lQ2N{Ns#j#+!4%m7H{l|6d*?yl8X`K@o0jaqfV z-dmkSb6Whgz`2bOiU5Z-Vr~nS3=nW7=i7L*U}Q+FLyB{1pFct8H2`C9~LzY?EY?b76aE3ks zid6d(8kf7@#M3cwb8BdC*iAJRx$PbUC(eefE1GYM_-R(?r_&i84gjZmdh(6;M^QSC zb@Iu>i$-45G=3uwE;3kq`#6KQwiP_ppL|`tC^U8tCDPCeCLsn&e$eXA(&z*+e2_OWp7D{oTL&arZ}m@i?E) zxvuMdy~peI)DoSuCM;6QA3k@$vB#gMd88|;5IoLn*Fj<*Stqy#7O+5?O?%z<&C z%qC#87M|42`*2#?lNI0eo*MUglXO|j0o1gLyNqa&f)VjO#()+d!5d7t=1rku7Z$&Z z9d7+b3HU!1_3RA-}Vm z@uAd!7TuefdCFnsej8^rzq^{^ATwMG{0(v7!~qAR-mK^5RM4*#0_h-Tk`4p*AW;qs z!kh+mW);7neWtzC>s_`iL%%90PQTMDJN3i4aPho;0TIxW%PiRpGFtcP*A*qdk=Xaw zdh_GPUb@v;8@09drtFkPlk2Fx@9g}*cOlIFJ(ooyE3>lw);sxM(HGC{TbXMlptPY@ zEJ{@R`?+FAFj+aRSRNW%Q)0qe-1zAl7lW1o_8w29E_o#(QuFFK9fKN!*(EhP)wdLp z%Tv15LO=e}v8$tGzgo>4QK7YyoRA5n-qs9CBDdMizP+4N?*t`U(0TzHeE!UVd8;CV z+yoMF0r68D6j%7tK7FJPHf)cLEu4iwV&h8@a;>OXoL$sW_X(tHiO-9Vq7HHzA}o;$ zATPo!heU}O&0d}wtFWyq(RrKCiW15YaN`(zZp7B25XvHz@m#5!mr1&7#N0>cte+d# z97{Xat6;M6AL@X8H|-vIWfTAGbYz)TXHmySR%bF6i+Hx_ovM{5MjOs*=pc+C8<%A= zE^Yi*{v3SI-AO1^PVxFf5uzq=m2SeGCruDZ0tQ{~h^Q!9K@7RsYiP+D z;8cI|z4@~hw_!B-CM4wDh1q}PgBYos#2J5LCLDDLcWPd&;}=D@k%K1qdy)$SyMjqO%aUl47Dj97 z_VA_BPpB2F{h83v%ttJ|chxISB=t1^VfB+1{MPC#Eu7Q}wzAl3*O3Wc-d7@gPi5$U z@~H0_QaoCZ>2KOD6O$1RC+4qZ!$W;Of<$EB#{yf4HZZ(>^0=)rB8Ow}djQzlQLE_6 zFdwo8-%niDXPKVReme0w?6WFPX-nhaaP4SYWH<3VpDePAi`KX}DD%GpjVGynzf0cq zjr0Gg_yvR_?S~A}F4sr;e{cG3-$YRyo;Jmrd5#;r@G2e*Aq zQn{bJU|dBxO3`1jJrSGjFC*M;^i+(H%Y>&wILB!5zR#qv2RYbS-3{}aato%}^NWIm zPj&i0gB+pSm_Ne>qzmjDbb=nAiytnaN}xYEEsd4gmoR@-?Heao4R#vyC&!+G)U7_M zt(2vK+AieSujOkymMpGZ@^4Hf`b$z{96m3YSJ5_}!uG#j?dsz5+pjHW!UwZI+Ifcf zOuWQrYG_X=<*cAO-0T;G&OWG1{*7szR%#xKcg!}P=^>&!tFp4tu7W>jN$#~1a+1^5dOUI=2 zE4U#itZReBgFZ9xZ_aLb=rBnl(P&e8!}BcE7c*=x4;(!-SEeZNFmh1V4~%W@*pvshJzT^ZZ=paHq!iO14b0JwZ)hQ&nYq8zjdUVhhIC z;&#IOY_`d?-zB@EyoED6EcNJks>5nEBOw&Ov89T(Ca<2p@TUtO8)({Y1z6GL9V*5C_| z3SLaHD26~$Xje#gjAD5*A53$W4lginU?Yeocp?)*-DyXwX# z;nZGtt%LA1$w>Cgx-Y@xcJDYeVg4Gena{Twru~PYO*P$rSMZoA8rzkLZ)0JIlbKZX zsui*rJi)?3-k>&(Np5_E6tYPTu<1Q?K%KE#L)T}juA3+#bf$)h&NsQgQZ@}4Y62@9 ztcKO6(zgV!@WryRwDW>}u-u;1z^I!{vrlGX=xwkX(PoIZr2N*BtN-d($>>ngQJU8( z7n|_19ve(_y&70dx!a_NRlYhNmun!zwqBp!bF+a|TuZ<7P;!7)JoPBJrdOKBn#2EBYe z%l{Q+??hgcPjZg$Hl;ViS#3~dRa{19Q3Ar?rXU;Iy_+eDWY!_@6(Q{hXT zQ8li(IM(QQgV~VshUB3-`uz?W6$5U;gx~*+ytPuwXgOZ***cn{+&s?y;~rzf?asBE zQOfU*EZPsl7a0PMJkORgRYqX$Vj@=mMqs%a?ivYJ)!MjTIs7ZS|$xetLg;^kOAwMo9bCeM!I3M`hW|^k)VEJoC@d8*f zKb5rXGS*FtP3A5lUw=LNa+cRH>>p@7C;YGp1@pgXq6R9`bxf4|Cn-iQ&Y@N6YFEH&U0ukop^#OFCVOFt=PD#|*JLI4P#|Mx_Mtd# z8xd6;pQ$F4{ca9>8J{0&s)u_D6c=_m+h>-NSt&~KgdLusvvhlj_4x$T4?Raw4gM7W zrv4&b7mshQ?zot+@atJQ8lOLtaHXj8JIl)EH5P{SMY3S^<(h}3)rWOG86BN(HCZ3t zje35T*W_p2*EFSEDoMJ{QVq$0=y|)KzQdhW^^dgIlIu#=6X&@F%(-m0gz8mD)NcRT z-z-2ND@B&bSG<^N!{()fkgHMNX%SVzGu7Q?&U{Ln-Dd66yW<+2-Q-W1N2~IFX0i|{n0uRL zwz|C;;~J4O%NzpoOwnPr7k$VZXp+Z4MR>{Q@zV+D1EclH8M(#8SEDKhg-RvH89jcY z>mS*yU-clj-aNkP&*Ca0@Skh>Y=M)pfS7E0%CGCu$j7x*kG4gBzr=*5Tq`ypHDC^y zH?Jj0-`&X%O^M*IGG(q{wGSC}x#^vsb|S$&{qQG7hfPtr47ovcN3Suzn4odd3qgjd zT)HMR!jmkGY_zthvL}>rIrba1Oa}FN^d72UV^0@T8{MPYDSz^j;_`a-*TOTaQM+ zp9ry~36c?1Y^yjVoTD%D32iaq>Tf=(D#DucNXhE*1TlNl51PPKdf{89RsUt>3&+Y~ zzG)6%l%xpK=dn3%lc!vHBm_CWuHii3r#y(h`0fBgI)a%XqhD*5M5A%2(`&T(K410O z+UWHcy)Q2W_b^?jTYZui|7nyel8vdiy@VyiYT`wcor~ac=toq#GcKN|?XX=i9OxHU zh(`AAI?_23!I@+pP199B==ByHDqCA;%S@>iliXg>Gi_>!l$Y~g&yLIcvHH;8``r+TL`#K1WLfP#&8cd&rYsHo96#SF&3rNoMxtL+Jdkq6x 
z(K^Yole24F0sM1Wu28y|og1%rrh6Xv>=Jw+W!W2fr#`0qbA7@y!Vt}cC5DFkk!4W; zo7t&iy_B!PKWd5NL^*oOlvNXI88DZRcGn8QzC=4MIkx1dPK&IpuzKZL4D%iDu>^O) z&yz+egX~+tLWm=OvvI-62u&KaLN$Px>nu0ORrD-H@5_)g*GE+0KO?TLr!sq?026Vw zhwE<;z|9K4YEV9TdwqQdmB?aJWlP|XxPRa9YrRML^}&p>N(%`hGisqOq${E5&_>v+V8zj4K+zGcXOPHWDKzr!zDBjX(gWHCM@JM!efSUSH^82nRvW{Xob)Th_%FH2qh{{38vP2sp307beZk^eM+ zF){!__#MwB{GUw2X;w*NhkPwv-9pHXJG^(GPr?rpRL^gX>e8N*m^246F0U!g3u9kj z-7*2GF9kP+*t4ro*@>ht-obad(}&_=8FLJd6Hr|nX~cnP-&I>FkG3}vJv_olZldhM zd(3awszBIQR48IM@(V6(7rvagAtd2hw>|TLo)%NmvILWaJRR~)Fdxn zOxx*{8cYegPCm&_WOHQ(CY>(2OYl>PgElyuim&njk$R(lIY}}}DoJ|3i}1F%q$Cz| zGD(4{6;%5%G{Ro>RNUJz-63br!t`OK)NdwWl8Hz$Q^$F=2r87`HC2kuuB8ffXrCl} zn!VlcC{GSzvkm5OSMOy_AD?Or04?gR(328C@A%9}e1<}MLo7>@hO0b{r-sNVD_oUm@r?1jT=B zCR>eq?CN6|A9{z#`6+C;3r$S*VXfx_1^37evp5=%z407{$L)=6NDHOx8!W{CTJP4z zhg!o&45;fk$S?4!I$E1H*pDb8w2qk}&BJC|+)M<_#+sPfZ-z>=dqMv_F~r1wkN+J} zZus|j&(5ekRJ{p#^K1PB=#E@UcFO9;BjF%3-ng;eb?87J7<_>z&x805UU*_3HktJR zC?&6EZC>BF!F+cy-CRxFekFQz;_rT70Q_EGXa5NJA5oXcf^ zIvLfugV@)wp)B427)*>_c7BOQRHAn~Cmzd#tUH{`ZV>5unz(#9=(OvqV{Ten(GPxg zeX|N9aqsW&{|y3*4C{BFa53;c=_C{TKI4q1$=-gGKrm>)ay&RU106p0;tx~z_H0g1 z$27eH7qs%29{&3if1!r#7cW^z#~B4lYd04`GD4Zz$hiG`(fFz7Q2NWW5_YP|x z{3_DlBl;A9;xtn0_MviJ^!$09u@f3=_Qe1Ny8KYshD9>Y8x^svW~HU9OakWgkelUo z6tBe(osO%nD3PoIX_0#Ml~yuVJgL4CLellun^)sQYlxgi28jr{4xnLC0AZ|s$zRvN zptQzimQ?~ z{6WM(#s$G9k=yh@ebk%2w|B;4y7im2$=?=wfTAS5K>PoiDo#{9jLrZGAN&Yg%_zj# zdxwx+v^{4{djPP~Q9`*m#ir?wK;!f$yAVCV z<`l!I!WdwWY&K87o*Hk0`!H6nSw|a#^?jMI0*!Yep)NU+O_Eo*G`vTpWZm2S~hQu{Z3DjLCg7SDdg6L4b|X}4}Ir*=(o1Es_W%j zr)c27E+n_XO}FW720mcsKHe1?AM1_>WCvQ%zj5ilqR^1Gqsn%w;e>|rZ{mLbz}hat zBuhqYDP~PEXDNeXN6Z5kb3i_(98cq_YM&m9ceXdxUB3N<*OpO$ao9P;@AnByd)$!#eY|z|c+IK-WzWm=7Ny?##Ax|ula zkgbGYSH|83BuoMDI5VnL+T^H=0o0_#sf2GH`8vP%IePXD)tqw-Oiv7O%-qnNIiWeI zqcsEVG*`g5%=Mi#)81EtAxqqO(=ugngJ#h6FGwQI2qN}?$247gC|^SdhwAt1_TD{E zLSp-1Dq|UVa{I@}IH{YTF~I?|A8tIW!sFsFcEd;+D#!`eWDG6^ZZj3Rf||Ty^6J*V zRJqm3hMJ8+vPa=>41uSrK10^W!7HMq67K-mEh7<9zr(o*WZu>YJlQ)F5o? 
z?Jujw9>_U}V$!r~3N_M1rF2lzofRDgXJ9d}3{xoX^S+7zlwi`USFdcZ;V2zG=E&YI z;U~C0fT#H{&d8kn?`iC+`erj1mq!XKCfUqdf$~5*TI)Q-DR+A*egv=5HKR-Pe>z;1 zNGmGhJ#SQSUmbtHgbVCdNT9o!ZPVVM^4TZ+)Zt~1c?A-jTWok3xt82&JviI@{e8u| zeaDYX_Q#_1#jFj9>=qUl4V)e4m=+A=NUNNUAtP2vd~uyTn-=2Totzg$PwD()s94MK z-v!N|8#dipLz1#Hp{NWpMY$h-Y}@ETF1OuH3R{HyUQ+fs^ph914pzU^lo}0|7_&sr zgr@N6#s)#AU`t+y@gl=gG|fzFJrMIV)yo*fy8_=|3(`gEV zRZ8srglrmpbN0esV1r4)5iTrrt3++W)i>mp`v)T(#H?=gWMoOSZLZ=ySJ*+@J*RKI z2e**3*d3ym`jb^EqIWLqcZ4&r*c@WSd?8a zQsx}2))~=T8{fhg^@jfVdeKDPmeE3`sKY%n-0_0pJ`*XE@pj(qd*IWoQ$#w0;QtgR06vCcJvgm}J$D za=s-kjL<_Fghm9JoAoKGo^Y|kj#$$ZSEVg3R^+w0mzyecZN_f|I6@jxiEh{-L%YBC zUZ$-)PQfehZ8Ee)VYYi%z}HGH*&500g@@cg&XvKtGM4V&{%OfO&)CT2^WLT|n4jB; zB0(0#im|i(*_c&jjwb}RBU=qk-X<%rESpPA`C z+vn71lZ3dtJQF}J8?aLRklevBAdDm-)L}DJLSc)-@-Z)G!WJ#NP##SB#hUe_A!XUr zj*E*sF&>nCHt$Lij4m7me1gyL%y7bX^4(DA=%w#{ja=skSAMwWyydGm0NL0`4g1HP zFoMe_ZicTzVBA`(TFl`fbs^?&l+HBzUoZsgo-5z_Lx)WZH6EB!w;K67O=hvcx0aU2 z_XhZ|!IiK77p&Iz7AwtZ+=MyF6}N+I%cWCHDB zV%h8P4(-M~9`TV}L!PR#kWfY6wFP5;s#{aOXYOgp@e(CgqjpaDFvf=p{>O={i8*}J zM_&(spKNg>&eH|XQK@NZYB0Mf{GJn}|MADU$M@`|QKvvfmkPQ1|A&d2c?%~+kN1l0 zE6*pCnx#YB`9InI^gNdG(An^J^tK^>1Hq#y>cVIR`zH9Vc41RcrRSntT%%N+qNM+<+2D;QFYY zlcQV(;dOj?1`i#ce(X7ha~}a8ckYn^)0wdYz&451xqax-*6_%3cn)lXo_G#bfSU-k z?9av3H&E~WFohLXcqBZ|SZdj*fYWHpC+;`?KLef&x-%f*a|I;d#eglR>nEGER*8@s zBHWp%!}SAng?dgbT9qC?T`Nt?>Dcg^q&GS~G>DkY_AhbtDgc&&HMt7<_mT%kNAh1z z;G`xYQT+gSX1P;;$~I)jO+&Y>A~TL`yFsfk+xj#&H=Efo0iJ#_(`_#H)l^^{yATdj z879N&cm0v%jf{;W)dKtd5a6{h5pT{YrhhsAaY(U%TyLhi^J_|vc)7A&uuOA)mo4DV z*pyOCXH!)VZ5N8~;*MNaX{qgG0U(lemNLNx4L@{Fo-3)Wb}< zT20)r{Q=3N%S_#U$vkd=7+zRiU48HUsPO=D>17xV6M@TyosslzAew=0Cm?%i;d{SH zIKS5q-1`5{rOOjpj(OWR_!fRys_8~ZV`I7o;A%a=H;xhr+Wkos;~sv`4Izp@i79y{ z9u%>@3sw75o{?~1v-D(?grPuQAo6Ta3&jHcedMv^4P@37BnD-GJt8*r-KUymB|RrI z!@|QoB0HeHKvcC88TT}djHOT!#2gXgVZ)?a0wb&@=yq>>C<>l(A+O8h6sblV8|>%h z*c$hxN4x9abgO#Srr+dxq2eaGh-nd^U*eFHz|=zC;PzoUJUED1zJ^Alfx&*3k&*GT z(IJsCXKBg#4O$E-e#dUbBqSuyiGL#?HFX8FeQ8Na%23G^1IwZD?j}}UL&Mn75eu*& zgCxBq{UifTE9^I~frDf=48$c8A^rK`fxU=+@qY+S~>YAD#K(Dtv zQFnt_1zcth#OWySEC1nA!e0S}9?$#IG#Y#;W20LeZr^#A|> literal 0 HcmV?d00001 diff --git a/egs/librispeech/WSASR/figures/otc_g.png b/egs/librispeech/WSASR/figures/otc_g.png new file mode 100644 index 0000000000000000000000000000000000000000..ebad4918023445f482fce41aec6fe44447c9d978 GIT binary patch literal 33339 zcmeEu1wdBYwl313bcl3!HxfVH9nv7(-AG6y4T6X$jWkF|NJ)1o-AH#!{lHr&?Cn1L z+;h&m@80|F-S2I;u-062%@||MZ{+fYqPzqWA}%5X1O$?lq^L3k1Y{m?9Rv>x{B$2P z$OL}fb5fQNhA8gGUx$ETo^%$|aJF$bwXimUpkNdEc}2m>Y;Nb|Ou;5f!OCi2Z_j9K zVPNKHVB^GSYvK$v0ncshj4ezpOpJf_VP$4zV_@cCVC7I{;ig~{0=lpXF|#wWbFpgu z>~COZVtYHFoSmnIwY32St2hTEGcXjjoPm*rt+SnzIR%?Aa4%)+Y+?=k2AYAN3M#-4 z4d5>evpx%lJ{JveDP(VNZK7^sC}RPvMx29?mX-q6C##lqOc z>2|cAP0n_9*3K68e{3|ev$Zuby5;Zfhz5?1c5Z*{W@=}BySv*y_Q1yfF#1h{ijjf! 
zpIg-}jGX}ixsgZ~Zsyzd-E@~Uu`n~gotoqBfsMgk=i8o6<_5-gZa?4u(ex`c(ur%__#=!2leP8;&|Td)yJ%#2BFU?X8`ov$>s_ovnej*zc{vj&?4##y`$DAi2W7 z_mQ=;vj<4Y0yJ8hI6HgX3XXw`vmMZ6?rdXy`_#nU!dc^&Yfa#qnUMqNbn`^?Mow== z{&907qCakIogF=HUi$6km$z=7{r=jIySw*H?QETI7XmPeP1woV(ay@`PHb-0`m;sE z&f3lq7@mTajrqraz>}Y}u{Zdg{IJi5+(q60b2(c=xAbX;B4Xg>nNgm2G%aOO9X)a$4Xf5&gS2Upsax*fLebL8v|<#vm0U})#;Y-a;Ja5QnU@H8;|A<{QYx3>c{$q!<12vcx~0xf?l7Jp{--`&UU zak-h;GKaT1@rru2-V0i!yf8pfqund4MVrS*}Irxn*16OQ6@BU?8 z|0}cqz6Qk`A^{7!lNviyQ>Qz<{?BFRZk!*I^V@d*@e0@gmHZzoXSe6>7diXEhCc$F zwKbrq05KGHGq-RyQL#7pag^NvT?w@P0#gP~_CK|ksfGIu-2F*t0Q1QD7u+y5FyS^e z0{Z<4Q8q+)x>Y{^R7u^!&W&O4U!ca* z>y{V4$od@@f3eK|;=slFmq9M>pE&;uH*fqnK-J$kd_OM%hu}A4{u|Th|D)IsckU+* z|0z1L|LpgdMyEex-K~|#ar5NHsJ~P8cSq<>i~sagZV~cdv;OZU`@d_+-71}5j?gVc z{(4yc;(*A0$IUzL-uX)Z<4n2RL;4FhZ~6QG5H~shGTglNO8+0{<{y1?R_5FLU(P1i z-*L0<+U@}QztqiQ7S&J_H?o%D5MyB0kT%!UR99EJ3jzMZO$sq_3SkinZopc(u~H~R z#3+P?|BgNSdrVOD_bA}6kwPOEN7o;L!kakT?_K`41BEv+oSSH%Cg6YEUH=~MyNMaz zAmq)Hzx9s(nAP#^<~`z3Ga4%N56khfbL(XA1lkP!F4y? z{{&O=?*iPP<4AvHfMffSzW`wC&j9!HJp7&L<6j8K*#2}(S^sAN87s$sHj@3BM1LtH z`!fOmI#BlKz|#MRvOm@r{Dq*5<9|AovHzz-*`G=Dmx8iC6Y#GCWh}Q2{ReTp z|8O+@3)z%^N-5nXCGJkZe{b|-`%gx{KauCJgnoY_-rw^6zY+TV&>#Os$@Oj7;4j7P zzs6C2JiRNY|M}qdg`2M#enx`f%nlzwKoCPni3+K@>uu}8X{bs()3f3(g zkE}Oj$2HpJ7#fchH$u-G8qW|1enc|h2Gny!ASo#-iRO_iXGuw*UhPI*lbRUZ@7a0Q z(|dB(R()u?^VY*`b?Y6kn{!!td3n=)iRWaW(dOapp}8UP??pjK2nxcl*~ikRP0YXl zt?g$QX6S0kWhTf!zVLh=b)HnsUZol7+wWuC^heE-)cxd46A=-?Y}SvT$A*cE8wPV9 zJ54+i7l*?V({i+s=0vu=ql4`6V_alp&bvcr5>}qc`pFQV>lqLfv85tNraoHGd z>Wn2Tb=%deiF|f>&at(%#cVy!yfxRTU%dME-Mb|9c)rS|nm+wCwtTYfxii;B9n?c< zT5*6hZO;!@DiDFui{8i!;6w?QqEc5%_VxA}Yp*SFP$)@kivo6L)EtPQ9Ql5LO8#wV6Mzj#ITIy(5heSKJAMhUGYlT%ZRBL(UjwXO`N11WuC`O0PGHYM5??23AMO(sR_k&msp;TP;4ub&(FYU#wdP$Z>2(3!1tdJ^zq^_<` z$x737)&Gu{#cr05QiDmq;eofecV4Enl$2j-Y3a0w8cl;5dO0pWey?q`)<;XMm4Va& zWS!V?hC+LVd`~_4Ss5WfK2Noq(5Uf&8LnJ!P~{9wF-g)Je9D5$Vc9y_=)O-H&Y%7A zIDRBqbu*p+&y{E0M zo!19|)qUM74=Li$XG-X~2b>>oHe!<`C4e0D-126gY=#pox+l*o7r3m)V~k7`6;Z>n zI9ON=Lh{LlSx=ZH?c?^cK1D@Eu}V;PMM^sqVp0;*#BD*yyrOS9ne!b0&NR6Pj0?%< zvvkHR50K#vdPJ1B+pNIH^Fh*L4v;`DW8eb&F)+4pDv2$TM4CIixlR`LfY%lK-yk?A>1VNg&VE`V zj6XbxB?Fo}sbZ1`{-7T4&;Y#0GfRkm@vGi=-g(o!gwffz`l}j9i-fn2&5tba?r8nP zAWJ|q?E%8liqoCixltK|;(&%3Y1dpf`hzYZb_ANS&Y&F+d48&>pX$h*3{oe{apQydPotc9pC_9@1K-~GYHFzp2Dt7Zh2)SfVn307bvNQ(xIz;&AsfiCE zQw7@Se0@a&gGnh#ArVSTC8eaBYjwbbBV!@yr7W_tat5Xr2r|69s;e92ZzwGu;v2w5>!k~Xb=d*=XnGVD9sx*dF-|mp;zyj z?y@;L33}!gs1*e9qTkBaw%TO}pRYtmi1TwR#f#qjkCsiNP4}=lIdjbo?ipcVtSx_& z5chnHQT$?mKQuB-c(_2_od7fNDF}25@(j!Y#-6TZcms@G35-p$8uFx`H`{>O<8=S} zScJrfoQjNvu~^ZEyUHdU`z=xSg4#DDOf*PKv^=++=_UXH_Qv#m(`6DF?Oj~N0OV9E zHZkJF-X|Bzc@|DW@ac_BLcz2Gu+;(3Z1j3{IA7n(eI5Tm)9%vuhGV?WgY6ZA zcJM^GnN=_LEQ`c*RAQ!;bD;CpyikOG9Zmp8CrX&ZOqD58l=W;x$t#8ydjNprEO7vI zKvIC%O(7#DhE`=%wwtYc0oXH{y0xwv0D{x1<`Rc42>5_;=DZG(dN2iN)t)_jCN$m_ zNv#XdSLcBNQb>tt-<~vR&sSr{h6jO(K<##OUQq8?O+J`!KGAUSi7$SrLY%n_M(q@} zXXA#16cWrPO)!-2h-!M`IF5QCqB3Ra|8=LZpaPc=Z5mWIBc+> zt?3s6nc{f`+Ck?NrPhz4Y^H)!QShG%zX90t3AJ;fqgZ#1Z16bXP(;@bYBz*z@7}$H z7pmw!psQ1|C_?@)PMb-1Tz&N1#rDV+U}Hl?+FU|Y-H~6Bp|55tk+glmXmrtpLFJ#% z9Ckk61GC%WkG|u3@#-znuD|6~F zmy6xTYgRj^>}JcMj37gl!%6&mFsNzcU-ixUn%kXL=gHB1bS_Wyy+DrSSJj?a$2=g+ z*EV`;N&uK(ibagA!p9TdSO)xjd>Zm!tj3Ddy)V9J2uC%eKPlFE0T`6_39A}ZJI=9C1w(y)@GvDlCI+#!8%gI`89_ zbf8g-!DDJ?DYIVXspOP-Wf{^RuJnnxyL0y}bOYyM1pa~DB&NRXyD%Tp<&I#FLsSC7 zsO+~y`92)(&u$TA9v(ADuvC^)dcE7vi)E$>W0qVc=lyeQUd~p{14I zGK#*r2tGTcUjXTr-Yf_U{?x<6z%fVu^rY?n>lbf{ss>U!HrBs7L>!7>mB89TL9D|0 zBHduN2dbi07BV3p4kIIDFH|0q$^2)(weeE2hO2Wc8ysY!fMB#!*S)91ZVpZG3=17~ 
zeo#YoDpZkrd}#|b%QFS2elYr8ey|MMjy}1lDzlgpf_P|((pjFKo*BN^SMJ9f38PQH z?d(KJy`)?KPLviKr2uLl@%-lc3!cl1#R%;EMZa3!SAgVaHZ;7smvEBL0nc}Q)H*#N z(2_K$So8>EHDpwlf++MeDlqDoNM*pM0W7Zc&%r8uPSiD^wh6$gsB%Zz&ynGPbk2+~ z3V>?2cm!&%vI=1VyrTL7tGpC_fzxj!dxu$4xm#P8zIMTpp+aL~=RG^wfuU&50Zl$3B_*9- zS_%vdgto!MC5nE7%3bfqM2&WZ%djLND+&!eTpYP+H=M4#)?;Ae!K7 zWxC&?$maWa6iJAi^}-?HLIc)g1R=i`#7e44c^jFA(IRVVhN#EG7WuzSXwWf<_OC_+l8Kr(t`tRul9CT91cW+M+01zZs;G<0}l8m@Lc5c`P zl6`n&-}~WV!MVmy2;XBxJk^?oSdKh0FiwN>u6LwXFMiOywnj|Hi%NCwmMbUD;&$ZS>^ZaCzP|7>N zw#zz8M!&(JC)Q3z&WF9+y_TkFLsyY&g(RaQ5;T<)1BLH@8hxg+0;?_D@ zOSi9kH-0LlU74fb5QwlUvF<&QfOF>75}fW*qPNE8m?f;>dc+}D7gTu~j{W-#FAn?WIGQ&PXvu@aod5!}fTvOp?=8c( z{HHdWfIy$p>7&C_x1c<3!FcqtwST9TGK-Xi2Z<&xE?|Avi9nXA2l+kTqmA^8(tEwM zUfguXHcBb)(BjlEht$g^yRU1eNsj>+bj$d9rq(U{yTghrGowd?Xn}Z?xpIfR67TSn zxlK(RThIe>x5WIlk#3e8ee^k3SE8*H9w$nC0)h$OyrnByo8W^e(V)+Zyl&%NSNvMt z(Od=wAHX2cX#d+6-6Q!1YY~sEqa1xH^Q9Ks^O5`{s_e>NUe@kT4|N}tZH&-qOiY!R zV$FAKj(sT9(X`PxjH*DnCw8P}F3RVx-|@{vlrK)@i-B$z^sD#F22Y0Iv3l@6k1L|U zdM>AF7t{B=6@3)X$~spbA>Ea??w%YUTwRstj*%D``dZZWb$xy$4gPkoQrl#Or2YM` z>TOAAHDVip6i|5}YJ8xc0!O%qj`@5#x+Lr_%*75nN6 zS>zLUy1x*%j*ia!{QPs1o>xtMeMo>ivIYb)8ZLHtV;)Ylv_Mf(QV=N=j5QTgDI*S_OtYcM7cozA(ZNn>&uW`x1Hhc8xu2-Df_eBD+&D}5(m3RW=hh{kq<)e zET7~WYq?V!IfUR7r1m0bfe=EU2NXTR9jl}}V)nB29{J>VaIR<@`^lgM78(}T|4OE9 zTb!5>rZvq;hDsxuE;X^DIt}Z@<2=E>Z4t(DV$I0f_aUV;5)xAEyLXWl73`E+b9Ejq z^|@^@T0NSi`0<%63_9T zN0+jlvU_&9ibW@~C!W96dQ^bUvdPQ$Ad$L8DYPEgdbk@Y9r~_R|ctR6TN<_OV{r685P=MR_Q*G{%89ov(NmtPLdE2)M~_NYtp{5rX>&# zPcs=Q8$xOgm-HjDf*i8ikdX8BkElG96C)L4=JQUet>)d9Nf}Wm+Ld8!YRnyn~bm70uxIIyceB{!TkoAbxhWApJvVX>_-LWuE zox_)B`wDu*y1OU4#5inNTVSDj=?oF7_OVqP35|!DzrXi&%@OL%o2g9hb&REL1^VMx zjJk+GEEI5GB;I#94>D3x2J1!$rdZg}$5So=RwbuA1JO%FAn-Q78G@)sR%GLUG-9F! z`ovSdbY9bQrtL%DgNmS9ijAHR{F^a@^>W*MB8sB=STxEandqgK*WM3yCNzdEoa5g; zb<`#j*o{1ENzSZB0L(ri$$>v+D3o(8%mz~00Pish2qBd}e)8lA4!1oqkYYgL=bun7 zM6l=;n-G5SN`3o;egu4aQODJ=07mVhD|o`o13EoVAAKlq{7&j{3I*-w4Qn8IU0Vs;~VF}b2#|0bS*@E+*2dR4rCRNqs!YT>S zlgdr|nv>Yf1`E`g@)l#@HP1gNWs2uMDZ(w*sqt@Z74_P0dpO?c3+hh57^oCedf9fi zzaad;W*$y=JT@um#f=e40bb%o2I9thXF?uL>*8XL`#CB)M82_omZxK+4RE^g0>SQGB^xe=OJttAi3a*!Go)kf0 zZvMz48*yZETymWHRkz|o-hZpE<=kf;#v2fV2aAZ)NUQBJinn!r@U zh#<$y%liP2Gr~h8I{`4anN81Q|9`q1o0%hr} zhG!>Hx+d~bs0X>&PjWATEJj&$SCVjw>f~}}nZ;Mij|gfAAn{L}W^m!;K?iAse49Z{ z{J7!YFdH-GI^y=a((B(Wb!E~HsUHSQ>4UPuRJ+GaOtt*-^z;OX1;>2b+Nh}G=Vai= z#l*$Swxe@%sVNC7HZmb(*V zye!TaPmNwDKcP=6>HMnEp;X&KXYuS>(;gGn#i>8QNoiB!UhFwJHSDzxO6#B=A1vGT zl;_?S)B|wV;|=w4>ivhBY6OE%U~L>?K4&qJhQVdqoZQ&=fJdTH=gzVwBlBb(NN%kO z4o_9t-H+5gUryBD+pd~3?24l}y}H;vWS2@&laZ16Zifj+F$$&H8f`QN#S(a!`I(MP z+-K=w9+0PLjRswX0z2~K1G1tO2KW|(f(@r1zBDWXd6HAEJjGL?(wqv|CC0GO;p%E` z@cCxRmIHdwlDp#j!AC)45>ZPYs*hr6H5XF^ytUT)i3yh8aoXT&pUA&^oRFgcxb~G_ zbCK;V4Zvvy@|lY-hicz>@g{fAGsK(qPwxd~Cpq}h>E2%i$na^$&TiQ8gSd?ZU3!ER zv+F6~S%ApV*&5C1agvQ63EwZy$)V!aRA2$ck-iYa1^MuG%4n&kGf_A>IUQ|JJp`O- z>WX}FMoHWWAgSUQnD@!2w7Xe6X@CcWC=uyH0tZCK&ctX5h-0@NFX?qM>D6`Cag<40 z%sI@1sliBS-RD9U_twLlXvwD~scA`-X}k)79Q4+idWl{sGqVZeet18ZB9Hw%1aEvn zuB3_jN5F4ZQBgTPJHxUJ7VWg1`Gob7Qm$Ji8k*}}I`NR#>7Ic}FzE@~Ci1h)S4Mi? 
zScx%VmpQve#&awFpEgu?glIHkixp>{eZe5pO>P=)WnNux9*yQB9pU1B<)H+3Vx94{ zDm&+V!AY+-S!--aM>T*F1ts=1WY4-xuqd6kR417-MJXSU7jAydVuaXhD-nEG3qas@ zeRVm3jm|JUY`YWp>J_X=3knl?5OE9+s`dgnhTkY;V8~VkX=q=hY(`;s8Az1&@a1-O zb>VQ?;@@{jhr(7!;R7H%gZ<__bhe-v6cHVHH<%3oE-`Hp#gVQ;P~|2KVMTQiGJot# z5v5qq)qa8s+nG1p(?iI52H8N6D2?b1KsJ;8+?xAEn(Ty+a~5`X^rq7>&%s~kH7f5X z>gy{Bs2+xin=3U+zc^;72;M!bgMM)cK2u3**w7$wArROo-~&Rz6xQ^8*|RmY6+)KX zghbcyqGBB{C^-|LX|l(<2z=x5cs)cn=9>a0gc)RXzN`-daoqqLn4ywy(?QBUlHa^L z3uxQfRo@boX;N|eNQ#rA<=NWWu8o!G_X|2TRyk8=iiKkb`72Jk;d6spJ5WqS^S6{} z^+YsdKiMw`7iyIH6TsWf)_J58sA$nmlI(x~ZW{8;DRFbGgl$0@BpQCuPAiaKry#i~ zA?gg-4?cOYcwUQHZK$U<%gbA`=6Q`C2|CXARl}h}>%mUtOH+mcRU7CTb}yjo710HBcgZP9d`Z zQxPQU@S!GrX$jWk=({u2}+pT%LIXNJ>rlCQItIG-M^Iwu!1H2-| zTG!dB=(Rv!AP}cmqQ_sUt&C1V%LrqI;(;}&9=|r*-V7Znvq0kVm|l$P>ToXicokG7 zM0H7VPAie6JX;)k`k=y2tDbyCvS6}qpBR<%{U`r{$RwXu=FK^!G81eGitTD2>~G z+gYPz_H_>T@HA*;ZI0Rnm%%pP(B|{=d&+*)+q}qciJBKWinOax@$iy6h$j)NN=uar zB(tHkfVzhMuT+<);H8%eh~w!IlDZ9EJlG6c#Ikrxjt zlX@8qC$Zk~j=W9&YQPbs|K{d!#xI7Tjz-5FWzqVFu>!at)8!%TuQU26MYrM;G38y*0ukk?2{LHAxis^;tZ;CFtsgi5M) zM3u)6NVnf3pGWxkePMdi?MrKyJ_sRw7F+uOpBvTp@;GsLQ+3%!UFA5G$be<{zWUOfrfRM(JBc+BMw*c;N#P)kjO{@{q^GC1#Rr=(4dQg#4sLohn!jhAc^1$zrCeL%2pg!q_ zeL0I~V$z8&FFacT#{?-h4FohBF5sQwvIz`4i+i&FYWOe^$cfoGcgEPJlk9>#)N_|W z$Ej2KKkSP1Q(oHmW06O(`T2Z#HA*U~S<2iKK-RVo{ZPf<$0N+Ru!2pPN6E#-MUxYk z_Tu|?ODG!Il7;xf<1eQ|q`iXrRjS$VXdiG(OC)?}WQ%?oKN@Z3SHg1Xw2G24oP<|w z^kuauU3+Iy=wXmqMm+x><0Ak4b6rrq2o0XcI^It2+%${DQ@VZ>nD;5CrI+Ibg%}7R z$*h~2or01ELpVgIhZYXjH6W%RhYk^S@mejY%tw;(DM=}($ANJTtgsnwzS;$-C=$bXvx{yQ&UM z=34gYv#ZV1v~miQqL=R0F`o&!2L(en4|ZzA;HS4zZribW`*Wu&l5 zce&vch@6R+M|z0syz!!Gkw+9uGc={-(!XpU9E(`2g&h-b(*u=RiMqRTaNWjj<=8pp zF`>{~wR`vh&%7d>6y&`-Y+sXhq`j9m6^Dg{N-1}KR_UsimXpY}o$_gW(({P3)~~x$ zkSWKi>pb1*aDQ`4!f$`=%NMB`SFFjo27=^y*Ci&t@86AG@<;5Oa}mBWdu^%vm&QJ* z@jBtbRpSy%hOP833&+cR3~hTBC}IxS;X5Yw5T7dZBlWyN9l;=!JljE9%NM5$Iy0C> zKx8K{ckc}vE@;-w)Nonh&|&GbJUNwjU222D;)+J858T_0;fD`Gy)IFERpVciR1tcUa7v}3Ya3-JU=OB0PrgZJM&oc zC%m`x%Y4?rPJJL=vReE%qaW3|^pr1nw_S0{-#;)*{b+rNz?*A-(O9r`?W?h$oEEqN zHK#r}($)H0L15abH~c$^{gx+0U@Hnja|O=8J5Jozhv{L%7Yh19CDxR%&Oq^wfTQR$e2dGLm2t@SM{l4c?I+d zT$PC8qhzXd9#xp{?WZvm5uKZ`r5}b&En36z(hcVtrZYZ6+@_-a#^G5}FyM2~ZsL&X z)6rVb#+BPPX*<&T@xtyblYxnf=3-^sy?f~_@Cf1PeOP(^5OT)i*8DqA>3acbWJ;cw z2bz_8?V=Ja>_N23P{NsL=--16MT~drNLLVF5E-O5tqo^TvR!v3Y1`YLlN2Ep(Om-0 zpVuQ6TiOQw2BBDzR~otvKJhpe{uZx`gGB+OA16K5w;a4U zT!}5mtOC8SbprDy;F1|r99o7Qh#WigJ}rjiM{T&#p2CWevhxwnOnKEI$pUOF9A8bE zY0#9@NV2**I=<3PE?t^hUmO3(Re9+`Ish$?yH8WB`#r<2*t7Lsz?8X1k)!*myHs+Y2u~CAq25%tn|Hou6jsjU+&eAjaiN3 zN^(jcRHjg(OSs6-`H{7U5XVy+6yp`HG}>tjdLC(dsQ1?+ONjhm2OB;TdU#Lv-Q$mg zoIYng7~9Ds-I*m(I_{n5r|Ol3iiHhW-tHp`_dj?a;?nEZ!vsjk_j7N&?pVZZs8V|Z>eJ(vR>^X|a_H_aVr)YDxo1LI1Fn~) z$aFq$LuONvo6l|FMk3RPuh3P#IBb0pp-;}OfiM?>ATcC+fP|!JS!~yE7JL~?ZkzaE zy%uCALJcNPHcupBU`R-EI~W(DzxIiG%Afb5Yn|IQ=}@Hg==tl@{mrnpQ==Bdyr>6K zIGS=fp4^L9lBg=pV<e2I+qREBlQyeBMlFI{% zZvu5~-AYX-Lkc83(1;*6M~jT~nUAdEu6mQXqAz1ZM|hHL1*yTE(L1Fr&pteDcW(+) z?MK<)CuwXmTnRCzfAnaVxz3l$eXg;GPjU$wO&AEeHPd4b1ZT@60+kq8g=qZz{PY?h zAe>fuySm^3yZ>~tlObYr7m|O*v2C*sM}dlhl9;=bh{TWSgYOBp`|KHt?lYgTzS?N4 zOtobqIUw;m)3IxCkQR*$YL8xYOK0<*`M8xQGHFS!!Fe4c*{3o+_@W0zF~=v9kLcO7 zYe(q7YU4?r#jWGCXLDD zfw#IDQYVQ}r3chi7Zw&e+B#<)XIf8`S*u62Zur{}zUIu*CG`HdKJR;KDXASlG&D5+ zdAj6nOQ1oq$NAV&dUIi+we{5nWs9`(4c{vCOH37rC?6=Kz3dCo#|G9f^bwZ`%6xrp z2keHGDX#JHZHUp&)H5fHnav7}H zS4$K4q_rJovU&P&)lTE$kQ6pGvWTB^=ahGqIRTY9Vb7?iu9!YNatY*xw!^(5B<#Lw`<*B%8UX90I~YN9-l<{^vvN!wlupT;rfiTXn^HQ-a7r5nM(*lxfsgR0)nB zmG^c^58Z}7^XwXgJLOz89U%k^ebS}c;_*}>BVO5!_d!4@DMX1rN%G$ly$^|rhD*|? 
zBHHNRc(p+nvBCeK;MyU;z+*voocL(l=bS>Ps*F%Ols$G-G_WLtSdKA*$eC9?WVXeA ziHIxyAu5P)x)w=VOw0-)LF#KmQzBcVNs1~7)(a|^nWOJKq&dtalrqi;3r-$^?_1G` zWiTI7267zl8rvCoj{(Kfv9YmRBdn~fkh+5`&%Rq%9$&Ylg!mt2t*i)HS|YZiZ1=Tp zj5#%$`3~!b9cA@TtflbF6SL%r2y8!xBdpq+`rfj({+Jm>KJMl97kS%!JVhyAOg5HE zahU1Q=zxaw^!L5s+ig6-p1)Mv=*Ft0? z$Qc#L>&`3eX3yRp_`)7uH^%!jN8pigLB@UUan_I~m0w?2Y87?>bXK~P~ah%295$2gu7)u*1 zIB1G>-N@&9yTRGXX51Cc`}l)A+}Mb27~c$27GItM(X&k7s+9*sRpBYQV;@*9E-q`i zSSsd1S{?#+62eM~W4F_Z?pXV{%NafXra5Jl$PClV%3}!+Vj!;8$(K`e-$UtIs~q;x z?ri$oVn&_DH{nflj99~aL5#{Y12fNy<>q~s={JD*>Bda#*z4v~A-!P^tI-cda=N;@ zvZt1erWUV>oTWXw{7WX}RHjhgVpj-KfIU)FJ?4ZLMla4*Q&|jK;ej$#Ne*R64`sOT z)@{hBQOT_L@B8SyGuoXx)ri0u#PU%fGMIeZ_F7;U?jzLIxg4aeC!Ey4OWG{F9QdTO zuN@0>jVEQNB;7BhAs6OTiXK#qB!5qvTO}%%#S%`(6Y(LMf&R2xpcA6e@>OZVqe!I6tu4XtGB3a6z3 z3MHy9$b~ZT=J; zmt#}n0v$hj+Fx37!urdSzT$QFEMp zk1tNSuGxw1wbo@Xy4aU^88RSa+%O=FXD0D6KxpcGuJ2tl`pGIIW@EkkgeSM^g?-V7 zPdyB~PuE`21pwLD6=(13b-J~LzQF#zj*@aBnDzkq(ffJilyz+_aD>!$%J-1oC8t~S z^iec0JV@+uOC{TyMjlsVZH`a%K}YwR#->0cX0l9_PI#)(z~_8?g!&K*CIHctI`d`Z z(God$9gl+AE*(rh(>%H1c0pPx&;9PDa_tgouhZ0^D1 zuJA%@(AdK`mrYW8LlXCFD`bHGxy2-gPNOrfM3mQf@Uqx6d<>k+{l zHGgwz$Ier6E}7nqsfV^#SB5O_rf8@jhZYuah6*2T^nyCiNT^KXEL)?ufqu$dGG6EF zIqv)4g8UiwHiRXJNl78)-mo{#!-N2JuW#wQ&P#nf;b2miM%n@eIiG7PAB|jnJsB!; zvC&AnPiGWC`5_ysGI6Kj^WLnS)2GBjqRsbir^LTUGpvgd@Wdv}_BJ*6j(sB^1-za3 zzPHd3t|u zudin@?Zt@zEDZMy#}Xn$P%jZG7UY+sgo^+{oQ=&4F>eE%p#mvEAaE@Tl>7m)PX3OF z>6<{(D!=gF;}wF*<@_wE%(lUkNI1U_Rf8gK58-+*x7eJC_;6-I9-sxRPmA-QRj;fB zW7N8)cVx@~!Od1#J&?2T+|k3S>t#5l~LW%Y0lL`IZP#-fHJq zq)EcVC6>hg$`)R)V`W4=9c{n*k{L(;=rqQ(oISB+wSI!645*spVP4V$y_Q6Zly3bI z${Ps5&#vnfe|l}95aCu7Li>c0yUqe5<(Y@UeiT6&5vEWiV9l0o&a8c>u%KXPkEdWP zR?39s{is|1-ir$I<;#}~o16F9*w|jF7r%u-X;~)R-O=}HntF}%;PnZ=>)V62JpqrB z4Xh?r%ih*T`Dcb1R^H><2_5g{lINF~+n^&kG5e!Y%BvlhrGfezqsOZu%_T0tZWcn$ zpfZGs*;&ze35ueVCKLcaaXmlcJ^YG*You+1!_Lolb>BaB~1{xw5Wo6 zT*KYe?hohC^CmEKO;o4$1V6mP3*x()AVm8|i-7yrjg9J~hZL6{fZ2)>K+X=NxQZ0r zuCA_f)e5lR%O?3VXje*lzk-(d`kv^UM80}S|9QsChM)u>^lmg=?cjd2rj$YpWc%P| zj>uI@KiPMp0tIt$^+3W4tT@lmJrE6)mR10n%vB#Nm^ZYs`9N{BpUc`{`Wlds6L$n^ z3C(?csE?RS=%t}0^<7*Y{9PEmUB2?-2y#wx7E^Y!T@$WANpPp;E7i%wzPdx`N*VJ~U)8z4EGXa!JSxiYs#5p5$=e7)Ah6oFTQf zwZiX;^cx9(P1g^Ynwa<@%f-a-dmNBg+Rk*#pG(sL-)yW+*RZYOilJ8krO^Nr5M^{l zg6cjxJ4K>~7i5xg>Q_I{mzSRh$M^SI4Y7Rg;_TcC=!tot zOc6Eh%W1kIR9)E`YD}e1D@E3~&Y+Oj%pWK)EVGM@O-#yLJ_dZ^3f2rlQLj|)8i>Y^ z#l=eW{}@uoYyhDUPP}PC@^tlq1tMK}^k>4Ex<@Fs)1MCl;7s`SzNNlgE%P{BS(~b2 zl7GvcuIT$~_8uyLH1&9YU#KM*3HXq6Yl1Cr3|0O$v!BhOGT%2Kv;KOFosI1!!dJV- z%U2(-fm*!ri3u8GGb#1&zTnqwxH4s&K*A)Nd*5hh@cjq&%g$(^EC8te$K&oii-nUc zJqW0t13n&%NqmdA8e$0UZ1m~8W-p-;hellNm`EOYs@(-0X?5S~QU$IZNYSIBy$nsU zIn<-M{*lfZ?>Hut$cj^P4oZy7=GhqgpcE3Bv96D4w~}e@MxVd!jewUHVY#BJ{+Y=}# zsHD(70#8&Tq5(sdycn)gH3-7B>B?QLQpk9J%UQd|WiZ=n2oPP1QCzhE5PX(Clwxh_~rX0P}DvlWpywsNq2PC`T}Dg_+ZXx zFinWHkJPT<<42CE?nKH^3gO|_R_LP>p!#FULZwWwwx))qBqFB!1En0Z+00}19+y?V z#gJ!DQ(Ak4%QWFDG22kWK$Jd;!tO3u{n2qkZw@kxwX2W`00@VyruLy8MlGR#42`G_ z>YqCZ@DR&XPKXNoBHu=tl_8N_XJ(qUPZlkit*P%ej`tD2NNH+F+L=j63UBKne$Ydb z*w@EFM-^yF1_o;?*@*J6Fiq`E=~;aS+cMua$=?CrJ0a}Ur0iMgiWacWS`^#QfY`&^1juUox|LWB%*3?YI9EH^W{S;U@xbm-8!Zat0 z`t2px^LW!<`xi#TI&CSh<+0t0hFfe&G4M5~`fHsUKr}{VqPyl{( zR+oyoEP9&p@P)ivL0VCq9!FSh@0A^|mAsQPh0iiXwnTZ>yu=&jq+4cRrr-?aXr^ck zH=imAW(gpK8O#oo)C-tikHJ1wWeFBBBYOIoYDB4Vvg5+!u&UN$h_qE)Y@rr(ttytt zIz{~3N3Ssw2RGlmrjp!!^NN9ZCi6F6yp~rhAH?*f@YhNv&jIyxNO+t}SykxYn)2*m zkwt-Ts+iXbN)RJ}$_?0SM<2;|iveatA9RC%dUvkoI|}_Mc{pKyL{M#eP162R6sz6c8CF`Vdzk# z8yPw!Bm|`d5u}mskdp2}VjLu;rKOQnI;6Y1eGlK?dgt%CYwlfh&sk^h{XCyz{@I3~ z7%VItlT{Cj>Rex3UdlgHw#JWIko-F6B9T$g%`QhaBI2}0#G3c3*yzVNCto>UJ5p1W 
z{l4}ADlQ~Zb@%si(s}hUliiSCF(to6X@I`=tytpXi%7I9x7Ao)g%q1~Mw@rc$*QX( z`y9Q=0_Bap2cy}At*U{;$b7z$2?Q$M+^+2)7xAC=u{HtPh?bW8_3mO>+>;-bUF;jGu{feE%xq}P5|~#c*;z* z2VG~!9^bHCK3feb@A7hs5=D_MXz9HaSyhnwNNyw%{#5458p)tOenCv07AUh?&y#CO zs+ZDxLrZ=kbU5{p@-uFok&RUebE&|kkhuO;47=u6hn`Jxve?KSyMufjB5EFV>FHIG zCF-4dItw@Yn@f^fN-LVbb7)5_c(?UJ0C9ZP+C%-VuG zGGFv&6@m~hPvom@!;Bf^O}mcP@x;)|ZxGwRQ5p;^;g_ZS}qOFm2v@v$jL()}qQaBcZ%lO9|ZTt>+q|f$Itu zpo==cz5NM0oz5}f>(RwyS4e%19&z^Z7YiM@04l}R^IyIzk=cUQo^|Ji8O#9lN3_YI zU!L{3=u;&PCigRq-2aZI=!iXRV*D|V;lh=)1lFt*;#UqU&Q*`lV zQc9cy5}xIHFld_STVUdeWf$@W>BY@9wx79d^VqH^A5lvfzav!=k)SsxbvR3Zx&pB$ z^n|xp7j#Vld!FoKAH1uk=TZ~mYcWfJ4CT!{l-b!0StzD+bZvu&llbL~Mk~2yG8}s8 z`pKS}(X7>`4X78t)OWJehew~s9WSQM_i5yX9)Ecri{P9+xZn9^;01KNkFsvLqVt;L zZTa6f^_^u(kfFRU!yHTZ=rG<4fAbv?8FE8&t`PiN% zFDK3gEy2KW2bj9I=h(Hm;BO1_^DUyLj2`PYf2O^bRpD~9`jf*6Zbsn;*>NFBqD`sn zCZ`)gll}e4C%xQwY3ixx`vZ9)YopV&MG-9!grzL4d)=U)_fq`p`73ooqaHkI31uI^ z2~kVx6oDu0x_wf^=${VfjZR*2=UUdCZ2vwfn^Cde*M`9|u+G)P=njhM_~h6z7glI} z=TKxv)+TDoXJoyjzn^Bd4o%`-PmQtaiX8E@KkMI`diXX9CewLy9Wj&1PW=eN;;xv; zZIbyx#ad%qYm%~VLDNw~@}_QSXr{uNR0LJ&wr>WqIdt=Dot)u8)K)paU_2NWiHUj@ z&V|D>|2AoTV>G)V=12A>M%L4ekl|?R)PLrkQDwc%f!YWW32Xvrb4`6^%~hzfR2Cta z3l9J2vm?d^b#Q$oLdu@)B3u;BW=*M#3TKl~FjF7+`w^Q#M$E}>T^nN!a)D?(mI27J z*p~Qa#1TRE_jtk}-VyZog@c(@p22$^N18_cYI6qt1T0>Z0Co5V9Dv(`s^<+oi6b(d z6oT!%hX_6Q?M#XCS<9|pA!ym(dxs|-rwsDhDka9rWgeA$DtngNzP=>|EZ#IC6x&8d zymsw)stYmdn7xIgPH%kks7d3cn=hyb-mv0997Jx8)Lx{FMP-ty#H z|B`kaB?~A7?`0??5K0^Uk0~b?YFGP^AK9q+JQqP>2Fq|XSl3_*D*1ffD0>$jsNEeF zruXC(*1~Lk%Twy7IUZiy&stc5xubk;&hEe)ui|EKb+x&)-*ifw2&B{Xv&cebSLWN%~$C)Cm)+HDu{ z<`Lv780Ngm!F4vh_HdE7`ZZF>Yt~O=x0>0}`H%t{`gdhO;b?O#x!QqYLqLZoFByvg zcxxEnq|w9e5T|x7ckT6!m&OPR%(~Ek0~)Ohp$mt>_(Pd@&4Oj#U3N4jtipG#Hvicn znCr-MO4jgnqC2^V%46>s{NZ+MudT!RxU0yzBaY_>yor0o!c2zYy^uUvc4y^~ue3NF ze9`D+dxdNhmfSt+ky7>U$3$|*G0iM$r4B7vZE_0c2M~#D=Yh;{(f-Yzu;!>NpN`Sp zrp_fMC7zofyGtL@ZX>MFSq1zvV#>5YIMqVuXK<-$3l=E57ena0CDp)R-w+=pdVLG) z#5=?b?}+q6f_O%kE|4+E`f}8*Jpsx(|F(@yRJE1n7~!XM_Ail9QIj3X2^pUh222+H z?4euV*!di6bX!U#|DoqXhUpW&A&*3GxG#{Ol~e?NCB^94S&@HXP`qjb{!K)J1scj+ zsxSZ(^4(TSYAsWRMW#Q6^ken07)(2$FVCBv+O^J(6b|Kjc`EDGLkTP27LMGto9`_B z-?cT_bVBbjz%fKFL)dxzfsJ%uwc+h`?mewinH!+I_^wO?&;CjhaN#juSh0psOtaxM$h9jc{pWeO*R>{NDywoC>2xrpl z=0iO9JUhULQmuVaS+vQi=}q4WZ&-!}R#Z3<>?~-%rk@~5vtxTeH(Nt|IaXx;l=i{J zfs#ve_``Oh(CDa+rm3p=u0&4M?t=gPcMGJ$50`MkoK~Rvq3?XILV{UKr{k{XG$X8Z zL5}xs7rX(S3SIpCBO>qtH{Px}XpIKpHhAu{B~YT6B)TPiFDhu*YKsQXlRq+op7>xE z{`GZITHH<(LA1hPA(Y~S!PF3*bD^QqaXLvpdh|s*#n-s7{2x}_?!h@a3^)Q`%~{6J zPdzAtMWX$OD&2LC-f>>KFsTr<*`}nZbxzvpD6y|`Xo!v{Y2AQFf zvu;cU66IA-=&NgWH}xyb7KZGu3t}qb=7NBJl1!u3)pTPr^?Y;lk`b!g0(Bd5@+yHg>D zIE1BBXkiG{SMB0Bz~SEdj|TGMX+{&c9zru z&fQPGe49A#;%xTTpncUfg}{mrd~4u=4P1<5NhYL0@k<~EZ}M^5-ZQ_m!c_FyhTtte z>4|-LkAgR}xl%g30=93u-r*V7ZwI< z>#Ipr3!-TWl3skKp zD&gM>t;DZ`%NM__3Q2?qWFA=pJlkPs`WXoQ2n(wO@Op*a6b0b!EO?VaEJi>X#a=*% zToT=MlMfw8=BhdL@pf7p#02|`8CCIYWYk?a* z2NEWf+r&T|z;_(TxK5TJ^%L1+kK^&M6VP`nEis0X->vFP0+v~qch4&-Dk9&cfh>mc zv4?(YS$z8{#%4X11|@R%0#*|wIyN?|DwZG*sBV=)41JSJp=h@`$y|q%#Gbj%IC_46;^&AAtVR1 z`82sTt&Jt(u_7fHS?^YOU8_et?-TptWbnxwQxrc;W^l8bW5>_mqmLbmG;^0phHn1s zl6-jKv8-NfOSuLt@-!^Zzm|`;jhHom^Ti~mM^!6+4K;6R=;oNsF<1^qZJKt7=IeYz z%o@T`)dE?(nBwBu5>kzgxci4Ps3r+ge|>%Z(yyBL5>HM_D=LF(D+NVj-WYr(QU3MI zRI`je!c~yNgce&~e(2r4R)*9|5mEib^SlEWm@rHHL zqlDvG&?b~don>8~o%!Cx%pD#c1~lE?TpQ#-98iKkM@K0}M@P3Lb@Rad$-*KS7^LNH zKtb&N+0}&`92`72HWpG`%+<5Jh~s5REWyXmzrQmh0s?gc>g&l;W75-qy&LtoJmFke zTqFjyr(?@7&{bn&aUl=|Jw529$bB0H1%*E-jR3p9pHa*XMMeJFqg#&Po2C7bErH3ORo*f3HD8? 
znl&LQo{@qvuu(A*77SkSS9Xk%Tf7nv^mZ zsSr_#O8?rNBdL3d+)Q?-fO+J!qOa~ zN+NT=gam}8g}2Wik8rei!=6wN<&Si6L2!tuiwg*WQ*rAdY>{rBC=VwNQDyL1%gxi? z1^f#(gWvi_;FlTrzp#*%u(*{Z5BO8b-QC6B)ZRu139d$6TvSp(R03?~*3mT9)8!CR z0e`z7o$bLt8uqr%DCjFHPH2=X*rF~hEG8fX{SRz3LO3AM$VI1s?#$j3;W&R+KCUiq z+V(2Dwi)Yf*R>N9wK5imIs_c!ZIAXqqTH}IFDxJ~0JcIqJpJ75=eOF~dn0W@gCZQF zY8*mJU?21&s)FrG@WBlME@-Zc4M6Yk3(ubciS5!|rXo9i^}Y4%y(IKRj3ixbCAWf} z`1EYJbGrI!pb_p)dMLZOj5si2 zVQ{f?hj}6G>^-n&o8RP#Lb-S%-9O!Ei*j?bx5bVx_KXNL8s+opVGbx4?A>9HaR)d4 z$=RU|Mz#o-&$pT)?L2{jKm#BuB80s@=x|MYq@xqAX$jmOR|M{8?4ce`2s@O|{0ZhZ zEgqKnj>L5*8ifMqTQq8mcjFFn^;NNVfy@-wk=S;CHGThg7rFzzFiZZ;={!Bm%yvm_ z<9C!&adg!3aoXu2?$0kJJwJfn2p2DG_j0I+a!5;o`O2Z9#-S_@87R2=h2Hh>^urB= zy&V`A>_-&Z(+TB>aznVNecY;S>xK57a|gm;7JocM7lm>Moepi>Y47RjhqW<;mnRBr za`JR_!S1#9MS7Znzrh?)hW_5R=x1mEK0ViGmoE;k2)ZcT)k|To z{`)R+zL78X4%RgO-T6fUB&pB$!?)k_;>*GH9m>JM0|&w1dW&<71Bk#`%->%D_UJk1 z_O&BF^MPNSUBV*3vp{3C$nz)xrvn(koGtxYe*<}@1%D$Tj{Wpyj{}@D+Hb*0K_9k3 zzG(hi(8rHooBISA4dj^sUGPK#R8w+6Izptu2Ic9Aas@kV>iNLyATi~doEqHwyvqQC{)4T}k+iwls$K$NcRAi}!frWuEbs2##y%E1;K_xZH4m9n>S0DC@1BCOB)40683Qojr>OGaJ{ zr?B~a6y$J0x6>ji4z~>l0gD9-$jrX3lKpu0kJSnm4d;9ZE*{N~l9uX673;qxbh$uu z2WV_|_Gp~BxS`x23;YI``~z6Mun3l{e^oDAIP%~7WGvAT5)c!U#H!j@rh~JzL7g zlVsl_RsUB72&|*{j17ti{b&IKB(g-rmnaZ8DEP_1yx`lh$b$po-_o~>ej&j4HI(#Q zb0Kqp@a?&fpD#KrE5pTwi2t78{*+nrQ!zqh2}Uen#`kf_@8o*FT(=g*#n}I}kYbt9 z;%{dhE`awQkGzN~KJJk>IQ3n(NLTZ%2qDsS`SQrHVvgOw6~YplVA@r(2jtZ6Tf zEdFqLYpD|h#bxs|;P=zFpdv^pjktL9mW?|sS6ui*6BeKQI@|>KTFKZ#(ZtuoNKvfo zT1JWh+3+tASVeL1#{W1nE@$3XNJt8+#^SIE3r#7g){2^w!klr<1g4Lxxi<>jtKL@;^N<55cs9U zKArOa@ch49&?kwt;zeBfV}a6txmf9Moy0$n{%}rmd7?iNVQlwfjT&o$%L;M0>Hk~G zijr6k`k%+Ne_CSww`TrBu zVFSM%hT-CpMd8`k3U0B%#1g;pXDFCBr?@N_29|k!ieb3YgyjJ-evwa$W^k?l-l+D! zPzdq+30}Ae;s3Eq1k}m?VC)iFl63or4g7Z4_4#5EWw1!Z<|}Im{}78d|*A<3_Z ztoU&`#NQf)aLXqbNdsSCO@DYAZ7HSxjmyMEvHI_#@tj+a_9r|e;X5iR|2=wROSYkB z9EgZw{WsRl;6%#hA;026(x94Y@e)fhVeDs|=lO4ox?giB?B<@JP;xCIbcKm40L?L}hg>Xjet1toDM z%ee0OXXpQ=KJb^jBQdOFUBsWhB~&w7md8u|24zZ+9Qnnk#rz^S_zUD5$SMAU`X#a7 zrSjvyn9fKk82KV@(Z7FG2B0i8GkunlEgZ(B{|mr=R?c-oEhP7N1gxW+Ki>B zBaWMrUngwy;>_aTi)9~di2QkN#y@M|J43Kv?gVgw=AxkKtM!Fn$TEQC9)C#63%3ya zC#JlxCI{qc30!$67Fe)7`;$@aw?-|TyIf?jUx=Rm@T`}_l4Shvob~!V!TyN@`g5{h zIIayIv9%Db{#a=R>$|=|Vg6z9>wiXJmd4>VE-b};#(BH{cBV~gslxn&>vyD=D$M`v z{NG7$=1JFLx`i8!-!fvtEe8K35fk)`*zbN|_#gR@4BP`ZzroL8lh^;grr@VgvY+~t zjHT(9|4Fji-i7e|vDSi>`m}8sA zHvW4<&HtB}i>da<&r;zetlemkHalobH6oo;OOxpg$$d!TpYLj;R@MO z_Pg-7fZy-&N;s`=(Rf1s=Z|=I-VcRw95{bHF2o)A@8iOczU!C2PY)+};RLYX>xFv4 zLQ54eTx#n3*v03P6LyG%!GAxZl>Ivc;?J+=w+0vS|Bp+)GcK8TakvO(vH1Tlo&tw+ zU|)B2zo2LUGWlOnMJbL;-G0Yo2L9v2n5_Czp+L?(=i$w|1?# zI?BZ?{r?}M`I97{?{J{lh4lZPEB=%q_UB4I|D&&nl)&P`qLAzR-dX-htA~;9*;$O3)Lx<^r<6df)w@fmJwHw=7sC zBqD*;L9lRv^|L=^I(~$z-x`Z>jC&EM!X|qb2RwgxTH>=D>Tg`REA|DwMdf=cZI*|I zd{Q@ut)WweumZP36`I4iN`u9ZZ2h6h3EYd3zd^D4VF4!=5g~Qv*FOG1_&0xhsqoK! zeJSW5EQDYIi7({2|6f5{s-S%-74fTF9ky`qe^OgpD9FW4d7Rwx_3X)f z@PiA1!141T+qaa8{qi_P0?Yaq@xbq1@bhPcEI4r{px;!)Sm0I(R5NmBtk-V6FK7^)xVpk#@j~g30{l z_u%6E<02k~Wz&n{=}*_w{-D>W;$CCFXh6T482=uU+Ru+bu`}Tt#I_$6f&O>IwuNWo z;3&awRtYTjvnzqcmx^uwtkn2BGlaNs{UgQtt%D(nFQCn`1w%OPej6m+gGZFlFOC1Q ztiSNk%0F^3Jg&a?8(iHF%Pag8$>8T!9V|pGSf98czWv!z3@%N*m{WXLj_wap+OWQA zd6YJxrIl<@JPw8AKV{y0gcB?S!M#W1w?-G-s}~m;?ze^C%fo>G|8d?nvmneL(i#ildC#pk9h8Z(tE!ligcHAm1izcI*)FIe5`yDJ!};Ap%QEw#`3rVOpg}M*FAC!B z;pgkKOR4ihX$p>DINGCJ?LEenJS?%m? 
zQD}rG66MD4>4da(cC+`uULMj7>50S*($ZrWjpg4Sx9D;{InKcag}{!H8w%Gsc1RC* z7X;R{kZvwWutP+M1?lRJLVE)1nCn8E!2*kRoZIrXw+;Z;HgEf2CeE8b_KR~)@8e7m zf$kAB;p2q#v^N5V0)|Z02aRwC+nhXIfwihaH22$G&$nb7p=UvAEcgo{Y{&v$lQ!QB zmMe(ho@KCzD+uGZIbrP`N{!&Y_up|vpI5j1Pf`VJH!N^J?B_+}iHq*4Hb`ruS^sufAu z*`cAt+~KL@__VOti>j{+4s4Ne;^Dv1JvQAto~4y2rj?ld!7pV>E1_#DqBoIq?Ba*r ztvttKVq)0DsG0G<{v*U}aF~@^fs$R3;Hy7=BnGMv9fM(-fnWP!-hdyj`OY8;kL0UA z&?&!qK5#&d2);k-HivrZuU;m5p9hSab-IgVC{XLGX9lMY*F3hi820rmeZvS2-Q?3G zZThHIzHE{CyQVYA3d^vJ#zn6g@@x>5xdZ)of2+c z+^PF!D_%aisiGdq#810%YiVvpM~9iT&-0`^xh6WfCV8A!!fDSwlcQt(%R$$mIvF~^ zU6LO^*G`QC!NFDQk6yp7i@aZw8Ov>0`*`wB*r??Ks{o7yZ4QH{XokR_YY3 znyPR>Nn?y;sA{Y^+OS<^zxxRWQDpWZ7>&9}PL^U!Z{s4f2BhP`G? zyGMklpQK0X9@BUtL##J$>HqSkhRKQ8*wt%gz4PC_n{2z4ZOA6Q>mlaKIkUeuQAQ)`Ap(t ziOI4)q^hYSBQD#Dtb>}aGRs}LdbJr`z#!;6abAysMyj=vZb{U8#w~%>nv|FCeqNzb z{b6=Dha3kdXR}Z59xh`3^S&=^&ZZ4rJ-n7pCte^fJDWQX%`@70UjIRv^XVB|?V|1H z-IB=}r3_x-?-3UB<+`jxLPt-3nm(OWKhg2%UwNLuQQ#Zyw;pz!cTB6J@s!W4^?%Il zvaYuJ;$*9A?&G19-J1{i%KA-YRA&wh*zuSa?5I9=kbEuc__Gtz9$8&`-;Sz}IZsEH z(lmWu8nMSv#BTq56!e*##0kq zhUB)h4o6&is_qXU&-eJ}ku!*C9`JU)dFp_ZbWA>y(7s?t#i?gKMTq_RJ1VyL@ePV& zVk+N0uQVxOD>Y41J`i9ZMTo2k#GiOxIhGrseT;mxc9>*kF*!lU^X|&bJ#SyHNC_|U z(U*IH;oDxS8E1~2L0dcmQs>5gk>Lh}D0+I9W@Kh+bzkrr-Q;k65csf6U=mwO?6eYx z_q~60FVp3I$;8Xs`sQ-Lbs-cnQ@#7Ffa|i?iQ&_M-2P@3MXDyAZ^{Kj-DINJ9|FMfx8I~9Qk0m^OyA&V$j9|>w%&TQsH+7h+ zD!&YkSTD<)a)2tDnE+wiCwS@IDb{VrMf99%ck`Q{GoGy0OO(iYLffG`J$k-3JT)~{ z-@<6sF)&2$0r;aoIM*{lQ!CJvCNwnEK(ehI*;8zLkimJVHMua$*={CzLN6kr2=yS!!*XiW=5!De}Bq2Ahs-RxwC- z=zaNE^;^59qZ|`=%~2$njUC~%8}Af|9M+h02zFeJNk8k@ma9cD6WVeAbM6+jgP`N^ z6u=Y)7GW+{mKj3y{aT@Tm5N|;8^gS>Mqx5oQazZdaV080lS?8(w&hJTppD7NLb zWBPn}YmO`K+?l({bG<}AGJ2!#fwOgQDHy~IXqi2wFv-E>iEjn0s`CYdlg=GKf1j6? z->gWH$+s_vKt4=p!+L;thZ&`}0Jh?AhyS{%tml)yD4{*?lxvoj2pDo1RP^>(SA>KBKx zRV( zX2qv6v!*vKLHd=;o;#&0WG8V;+C7k zh7omOUfl*ItbfCl&50QSjHPZcm8C>nDeF=X@uF2%Ri3pN=oc?_Ox?5ffP9-XG?!^@ z?_?XUc>F+3c08&!OBqw0DCbvVc$@hCX|;yXurMl&{Ap-#11kYB84xfHe*c%;1%Cru z5%MA*S;;5r;3*D?L92({GnE`(HlS77kk-3<}z z)!9^A(^KSkKTuPBP`#&;ot=G6JMibSIor#2Myq8E-z(HnR1EYS>#+d7etVyZ|EqS> z!8lXP=mw_Ux|?R6u%(wYUIv7J<6sa>IQ$DU`O7q-hCQg(m-T+|t*!bQ;?8%ocWA;7 zFt4Ju1EvwDw11^my801g-;2~7+D{I(C5Y=}>8A?NaOxT$UO3I!*k*vv6}Q{1t-gHr zIRGW)V7)B*1o%nAA>|?-iwmWC?5?!Y&o6!0J3VntaImYwW6OEBZh}s2ef{{TD7e<< z&By#EA2+2(Zpv-8s`51)bi4oph`Jq-Z*}rNjZ}mF>aO3@oLN^XR!y z%|?sWjEsB$uS-z9oAMq6#ox@(iZR|cU>7|85&)aFi%UL4nCbY;tWD(kR3U%R$`0WA zMCpCHrBn4|2&o>#1siwT+pmQ54uPtszG6HO$8Vw>3B)=d7k*!lK@P~NM>-W zYyvd0T62zY!A`&QM=J+7wxa4j!{1%jd>8q`K!D7&>qLjv-%NY#E_VTNZ0lgl-M_IZ z?5-L%VguEQ&};RZ|B}@&u>C*}h^7dj&du4-CM9Mr4R%>`_KiAI$EPaJ_H4A=kzrOG zOzGUb=B{DKgj8FbeugT^WaNY0PLXrpSUiP$l$p7>G(<#1*hNKeBO~|j-3zhV;@5^7 zon5D^2E%u;4)Jd=jo1&!UlKDU38nP`Xg)=x)*Nt2nMUk~ti*08WJb&Z~Ty6^YA zdS*h&SnO)Cikv(a5C`o8q_lU^z0D@iU%GV3^+>T8G(J+%Ab{nnbufbN!@!sTiDbQ3 z@H8TJ^&^eCdv0Bgf^Z~(zTsQzmT5RspC921f=IvA&S80-SIS}=p9E^eW{Ny`tHw`O z$QSCoLVoAiwz`gO_M~ibCunk;W2uDYn9(V&9zC6if?3**Tsv?PBLj zoMJ2~+s1b+p5OUBf}3=0xcqoeT6CZ|tZ$D!Pd!2@cy()gPXpVmMk(RG&boZ6Q_b=8 zB; zXtkrL&*PjAn4bEnEg=5O z1djb#;At9g!kY7-7bixC+iu1Lw$UhNmsi;s^jzW}{7`S>d4RM+<4Sl%TQM=kO--86 z@`{Eb1vJcBgMhf-Q7hn9ty1kqI1783hPko|>&*T%`k2fiZ$?&Zh&bPe<7filcD03O zhEn(4VT-VBh}gzhmd2eCrVv(8=K7;0(1xbpD1U zwOfFP#wfEZSh4rX0@pd+TJ~5|t7eF}<~EDKW}o`2;q2@K^oRNQ*=>tEk`yy8u8Ix? 
zE~0N6tWVmBn|JM|dl=5Td(2z6IY02(pLX)_J@OA(j%j6L+vbdTewtM6SM>HiXUTpG z#wK^Qp&^lxS@qQTxH}sqt4egXS+PwIen`5zbkLnU#!VXq)+_AYv1+{x`$5VLbiP+l zlz>5IAkj%EH!j#A9UmFC9^7(~An-G<2b&&%(a<~W(Bw-*}dPJ=GX)n=Vq9fXHG?-_Jl}nr*@aX4?rP< zk0L^KiFF_J!ZS%lTm)K%TkM|P#0b59vw6;-an8Iv0O5u-DWFg5oeR?B;r?4C|e)w^|RtzUEInylV1@BQ_iwmC7B|DOM259?` z?0q}H6iK?4cm~Xu<1Bc>nlxb({&Xy-!=&rpt6QdNzD=Ay_+}RXPQc~H-9)&AkcBFq z&DX^gNPTreLc$ddY7uW3_tiZLTZ|Fb8weT&CxPYVJ_4f8mPl!7>1o~u8$Zj-+SH%;K$B)hjjuTx=zX9mTG2Ip`-o1hxsn+&t32wEEZBpc|1NmA@ty! zYvM>(h0$Q3uoxy~e}ifyll+~VCH6X2D)ldp26htzLd7ktCA*4IDyqJl+oN2T}ZmQy!h zJZd-+(q4p6p|#PWIZ|6!cM;LSNJvu7fsd)`JnI-aV|53UnMdpx(;(yXJfA+6nU^Ji z9rAd^5Dhyf=*)#*$QD?nonW02rp_rB_3BZGlpXh)em{e#dx-u($UoGukz`~gytcXE zsWmWRe_JJYyZEJ|`j$W@zv=r(EKM4hCLm7FBvJRp3`M5_nL%5|^T|o&J4#ZyR+p;V zgdYZK)mRWS_pVXpVB9_R+=M9@bXM~{`GIx&kHb)zoA%5Y7^CCK8u$9r0?%RTB`v)M zIFX=^j;=1X1a~`5`cu}97dYwGhO1|4-#@RPKaT-5(y*w!Rer-ohBf0iW6KA|QmxE_ zhp1H^N)K%>rd{B6w_$*SvR|5=$&-od(q~LSynnAAG^fl7 zxb#!*kmSd>$@fsl;P=6d5*Z0y_GHm)(r?kjOSiJC5%QU)8hz!N5ExO!yW z3cq7w`%kBFa`sd8QbLzF45)5N1dUE~U;|D0UW*|MSU)k72K>l`@4MV65OIWwxg6ge z%k$`k$`Iasdpk50Y2xA3W3?-~$0#F4M!GgnPd`=7yx`?=g=4_-6vPnbe*&cm_I-GN zzNeFz>}oW5KUHU7YP z<~v?F%HC%PE4oi`Dw=Z*;<7;DiWN`khQ)4XD-lLAu*%LUp6q>>VYBb=y^$pKdqaK( z6vLo6h;><5nB=UzlW`*6$wrv>^R68Fh;yjkg2KYR?`Qn`iBmLYH9O`9laK)DyX}Xx z{pxyNjoddH%LIgo3k}A(Q6&&0DyiA~IAAt5XdfRHDUn9{+zp6V&n~Q}n0e(@OEsH? zz1u)pc)54W8z1t-4HJ(alMxHOHVZDQhvz+af7L;RwO@C}pq50*;REUK-%bs$0|Kk$ zgcPr3Rrw>C%LyQ!stT_G8nJP?tPx@i1m~GxhHKvjJ+8{l`oaj{o_60Yj%ZOwG<_GG zwed#OMguZ3j7F&X4yVZ*hg#fXphzH+x(1a@H@j*bduwVNC4psG2O|}ez^x1t^oBRw ze5(?%JyI1~bHQuaID(B)Ba*2-G9RKR{lwCGbFnwz@K6K*5VNBa07tX@R2LhY`e8Vx zIU_o-5M)7xg_ZRcE0h`Dbmv3@z!uwh@U zu5ZO7;7+?< z))tvp>)`eKlgmfA&8qV-LvzzDX?HAle`S&s8f#=iYRm$g2VNq(Q8_$@IfiUhjPTee zWy2Lh96MrgwE(GU;3k{mHk#g1wz9T9X-;l;^OTy%u7|2+c+%(r&9i%Efa*Jth^mg3 z-~BpU5L5lAk`p!S<-S;8x!HN;m!Lbk{0ity*K?Ed> zvfjERTJTsvM>Cp3Fc@wY5)B}@Yem7T5*5MJmijnWHYY z^BW-gJ%F4Nync+frQWc9^Bco-^_qZ4O9wL; z{D^A#0kxdI4At7FPoM6@mywg(;iq{*bk?n7^ifT?W{n$c@61y!4$l5Y)hbu2iqUhO zFN2R(yj8693Z#t=Bm?dljqfR;ML;jGX7{?xtj8ib+uJ)jvv){OEHLUKdi-}QRg7<8 zIO@iI%p`-eF%*o%K?ei&i%qegB$0q{x<7P>E65(36i*a)?s-4+inG$=#Y43+v(PKo zL7sox0li=rq+IKM1Xlp)6$^RBb90KdH`jP4>eACI=|CsUn*rxJLubLl+%J2l=k9rw zl~N@!?LiGa?=Abub+oB%|xLF0qn*@*-9uiG3zFEKxQ2a z{#$A(kn8M`C8-r7pW2n+v&B$l*Msxt0T&1dD(;q>Z(g`i7lkYOUnq83y%rgefC!br z+)G=)6{&-Ys;ULqF#XH{JB!M3&EO~u;VEhlN|_;lv(D!7U5|`4%O5pF=8db|?P!kM zI9L-zxHGlK4?BRaj~_D> z@a{#!M2Br;(|VsS7Lbw<<`s3i!3D1Pqiommy)8)%)_H4$mrTrwasgt?8`ED zuZWlI-4!0EH^oz{o^tEHFjyBNbP7J(pIqwsbW09kPc=r`XF)=U_-bGmSBUf>14&h8 z-=npO#687O>|pLzapjrLexWR4BOj2Eiu=$M`@*MgXA++#fdE$^h4S|DN}$BIr_R2A zJG>$jkJG_xTHJ={V~+^|#(ai|GW^|SWaG_rR@1`AbZSrb)(Lr~rNI`op1z#~_IwUORh1GALxg;zEghBneSZz zgxy3+n%zkYU)k?X|G`0!L^C7|bjR#ilMp9pW$8zH85E41!SG^|&e~~) zN33WJmJ2IgQP;NVt!|d?2^-_V%CYK=AY)p(-RTX;1M`VDOwFl@XGj!_G}cHkMoan_ zc#m`#ZYU`Tp?O#evRio*%^M4|kKhBCZfPJJdvqQz7%l-9L-p*4yNpt-NbFb4HO^rz zE-v252(PwARlNwQdh_tmvx=rk1MMF=3&sYtlFUF z;5-{1%XXR=Fq==$n4WjDSgWwNs?TzDzbpvh!T_)KduU>t+MR$$mn!z6QAl-_w~#ZF zhFL%f+BVNTvOPNaP-JBM{j-jF-L@(r9Zz!{Rq6~_zkQZ1wkCMJcLZC7uy?@T7gy+cNn~nQBOL_vZvoev(;CzeMJGdi z>JEVFf$F!PWiAcM`OR_21 z$8@smb&i+Kae_U>HR)*J-*@8EuDoe=%?lQ(%0b= zX%8h*iV5aGDNAMt$fZe5($t;PteJ)N`9+MW1ZdPKtEjYRaffPAL>-P(Ct!kdfJ7F2 z-tBux-n}i{EL9t{gDe4565Yy;1Axf-{M}pk=UviXXAf;CRxH=52trH<16YF46QtN~ zxkL)F)ClzBK#6*!qm}^DeCE;tBWiu`{A>j?lNc>p%sfQMcOUEu5d)rP-)z;yn&6K8 zu&Y`5$lF=k7$r9r3`cEkrg%sIBeBNIc2lb+3<=H4(*e_i@jF?lV%`A^T5U$F%tEn7 zNU={(%M_MGqOm|SdV%Z}y@dP@VM=7h&x5!;DGgb);$8ycUHc8Z;1N+ON}}w8Ax@aGphTt@Zjt(( z{myPf1shkM5G 
zvvP}G>~>-FXnnX^sU@bxm#~-Eh*;U!Mx+`Au)KaHh~_b%{qO}FN$5~8QB<%`}+&U z2dQD4EzIowPMurp!%Jm)2(JXO0aXl>Q0KM_6A7qqBJ&IB2@=|?QAC7^n@gMIp1m2b zqkP30mnkDEDw??^Q(L>fgnicfCcGk1cT}#S)rZ~f5}*pkCEjysl5)`?P^h%=1u+$i z^_#u7>)H6iFYHHFOpb?+d4;Ns8z11R0qLm27eXsCkDbmrsdre4#J+#e45fE$PbmmT z@+gS!cJ$XB%mkr8vj;6x{))BxB=cFCMcDiNVD2ACg|*(U;x{~nvDCd)i}dfw?yyC1 z0VS!kP^iJwj^O_NtF_VnZjT&V&=({4eDmomHbms($!3o~vft^LU*QevnR3u8`-ocI z>FK*>SO%C!$mty5of{6xNkuuE(4eNb~A6{o*c<9Q; zBncoltwm&>#lL0^Y~k4V;S5k~;95Gj+Yfv0NjQmN3@fdrQWPNd9kzaWq$-~#^;LvDMR4z;Xna=qDFm|ywenEfdS~+KI7{FD6~&yS#JFCqfI+1- z^%+0~eMS^ix{CPpVR11Al5!fL2R!!_@1Q@=8ff=mT%vYmdU<(QL3>{Ou$OUqh`HDO z+*W4fO9-dSuHiRjM2jkRW@h?5w~ptd6WUs6dB)oXP5=_ zZ2iMY?{z;7;WLM~CqiQagnj9h zXRLGds^?1=6bW9lM;*_mIm*o4pQ0@4%jXbb z4%&$B?z;*7tsZ>j3~t_+30R(c?7W+Q0@ZK<@xD!dDv7!@|NenZt`EfgK&{R|+Cio* z6)NOlJPGT<;3wL|rSaXk`dgoAEs%hAR;x(Zy09&YRa_}kdtk`(HEOPfLiSj;H z^;0pATQK(ofuLsM{zz-C*Cm+)HE9!kVJHLNZbi9TDmuEyE9Lph$;PSofSLef7P(Ym z+mLAY{k<^X@`pYvChq`|bnV5+icx@J#%ttUcx!r(Pd=<7=`F`j zHFE|7%b@r_%$& z5Ss?xASjPkK$O4lOi(JCgL5EINQEkHJ+`Y{b~Gg&d3(oLMj#%M`12L*51Cf3vnq+@q$(YQ$4#2**Q3J7=5iz zwVZhL=u!JShU-E^S3G>16GVJvSV%+pSl)gLhC}TaLx(ID3hN+a3ppEM~smqZ~RfO;-d{&sg)E zrWRzIqh|?ta!4V1l}0^``F^N|JWK;_CD%0mo??szqt)OG(Jcv<$Kn`ujm;V8!REo; zl#@0>N~NoO32Tb9sM2mn@4t7WZx2QyimN#Zc~h*f8>2^dcMWjnY4rdilVN+O2kALD z*M_tOKvd+?hxdgr_n{L$mY$MKv}){<)lwq~w#3XD_!u@~O+&*A657Y^kZ4FTvrdQi zb26kocDFrzg|dEEil}461s;LJ2ri-8M=VxB4#a}#VaCUv)l-PnHupJnAAPgmK&)f* zKyvNqgEG=1wm&J_1yML)BA#@)(MP9V1R}s`($)chSGHx3|xkc18>rpVOOlGt&tl)U0_+ zTw_F?81FttTqn~*CRoL+a&XnF>dR(g%- z>^}db=M_tqc$F73BMXa|&8#HymMZ_rj)#^gDeaN!+Scvt3|A&W@jD46<-H-YB30FE zB!Rx|$Gji0B6uJOZ@gyb)M4=sMD017m^sC9ae-WrqIP)W>rqx&@6a7p*~7cuzLpWr z(W22`ga1}@PT}uE!ispt%I;Yjc_^-VjD|9WBqI$bwS0dG= z3V^-2$ohkBpia8ITk$$_^dOmn+xDY|#(4Ut50KI4x)Z?juCGOtD&BZ-Y2PK%{iac% zBvx-Uwl0#hmS#-?9VebN5tveruhuNgnXVu_G#Aq5)(O!IOl5BTx6UMnQYf@M7h3&# zu6{tX<_HOM>N*&^Z#xEq$vzTrS3TPO5u(@`Vo@L+5QSn?z6Fs8adp+FL8jDq-+;`k zCQum}wc@=h594`H>|_~cW%cY1q_5;;6yG^eQc$5*r>WZwVjJ3$6)gRIj|)8p|uX0$JuVpM}S?acoT^#y0sS=T?QJXn>U8_K6<`yRIcpc>S(!F&BeW( zBsVK;dYvBu-nk zX?lRhq;%hhSt$O1RKA2z=55?nNpMztt#thqIUW-#l^Yo3DR@7EJosX`{2shPDBvr~gde6DTs$@L5NO$IVN_35q zod5V`qKkVv7&kyZ%y&J$>if%&QT#Ic?(Rhs+6kaK@ZOO)LbN=TLb|kaBxpS-Z_JpK zt4&+${cL9nsC2lWMzs~8+Bf@l%+yj1gi@i*L{#J{dTnx(hxyFU>B`^B2ND=mi?2&0!E+7G*acWM?I2g!OdC^R*&e2) z>D(yfQ&279GFK|vbqH^wU0DtyY$grxX2x~T;q5`#3BMlJXA$0VA0Fe@Rh~AHQn!)B z<|v{SJ+_8lqn{L9-wvf@SGp55)03=;l84^CiZ93w#EUjZwU5?cJaO!FqiYetyepoI zAAAT%NLax#my%uS1*bV`la5w7AEoBK^{kRr^rUJh3Q*urcRvVq!U8%6EO6E|(A}cGMI}yRXg28YN9-a$K zmxn>3<;?cyH&n2$XHJdTp>N&Blze?EK241OEA*aA9fb3JGXysS3p=Qziq@@v z4T9W__&IXcTW4ua)M)WQ3r;FbzG`WpReKsp6Bc^pU_t9g6dh6^tV*k3yq2PonL z)(?56h0P$#)gF+wrt7`@4LX4x+(xH2(;m(QIR|Y`1}XQGL=hMZZw_c8XSAIoV=r<| z6c{tK%P26rc1-3)_BPRAoM&ufc7Jk%rRnJFJ**Eaak($DRvMkrr z4Yd(M>cjRTwK?Nu$A8ya`^f3(t<2mLTq(ZFB${!@g*u`;wn+jXww53wg~)m~Sh`8! 
z!@G2thxrJ=0cturPD=TgQ2{egPdn&afVEPPj9%hBy4@rTkSYnN3PLDP5NS>{YbtA7{vTJSDG0@O{*3wX_s->k9MSVp_+wX(+mI6 z(1I%Hee+F#!z7L{eaXr2!uZqGWBz7lo_BJV35MQ?!N}X7@FA)G8@p+ zLdH!2W(nJh(+Oh=pjznWctL%trelX529}hc4nFQVeo`Qn9mDRjscQO#r{bX$BaX2` zIZ&-|8qBUNl0Kj5Xbq-l&^t!7W?)4cDE&QA&Zvu*Xy^XH&X=8&Gxw6tF}@s_TpeLb z*PC@cAQeCEO%RbpG}3d?T*k&RRx?eNJya2H1To9CLe_yzI%y<&zBGyW!|i!oppyJ3 zNIF%3g<}ICKSQg~WxV_N<^5!~_as(q8m}SLUj1n5=s^ojdaroaqlaqF-hT&+SbsZA z!>limWrCtUy)}7He;0?a25$Tc=K?J^sFph}D&NeP(H%ImybFt~JGqG@Fk3;7MV+&BYWU9PnomFoI zioT@Bf<%pdth$;>ScZLXh;OOK>4M7geZjt%x_EL~#66~Iv)PCly;T-nF6bxleFQp0 zI({Fld8}m22;T4KDALWgv&(F5zjN*AVK%)f-&c>*bHanQ4{8r5?8!jAA-r|!LS)eE zOiJ1Y?fEpgFwvo~u;L4X$Fr?{Hga-yOgl6&Kd^3i0gB&wVmS3mjCXvP1)0t433ldQ zD(|#yBJ~86BqWAcip*B*;x61X(ZmJiZW%7>QZ-R;8>&u?WiPQg(3m1GqtBkvJpTBV zdXJU#)Z`VYGCE14?`+#9qyFR8<&Ge2LdmV(L8b8gN}t{fU}#Ma_4g$a@O{Di_aBM4 zal=4MPtQc!Sa*Hav&+)ZTFxn3v#PAwbAv?TcVr4(e zv$#b=aYM8{s#&_uWFBiU6q1KIgW&EEF(F?sW4yIs$J4NQtzJjf^gvT;*3EsgfP$69z>9x%`zD2z&Lvib^+{~b5dn4suB**t zR7vFc^Ln!)7-$(fyW&QY_`NI>PIhB7 zn}%Ii`L3PO_+u<$&7fx5*ldNX1|z3^@1CjT0FdJ;Sy*D4j^3THg_ZmS`(w zdoqw52W%xQl`I~v*aMI{I{Lz6s>DkutLT%=@tzGFM{IW44qm=r^6bXjqsRO#UZrif zCR}^tB^7epBZ-_1T@MiNl`S)SBlhtON9^SlZW=O{ZA=y`sMB5B^^W9V*|qSw+@U#1 zMn=Y((UZ=1vZG1gh8W-rjT9wHqEa^1ud}SHKfR*@Jsm|37Cjx%ISwtTjo6?`cza?4 zQV`Mt;B+hc*TawY-OALN0C}|GHlcVfl9(zH`)1911r~!K&?tFU0;Ld}ok>F6iQsXh zBK*gs*RETTp*tmDqC3KTU*bz{RAQe^v1hD}?tdCjkdv8+q2Ag9_ zx^uxtlJJESSZTjga`x5mg=uNJif9z+DZ}Br5s9g6k<=1TAn-&#PjcHE!N^fgAMlW* zwe8~Dn4#qKvtpDrJBU1()}U#SH&^9nrnz1Ka|t33%;em16oiC?dJnqmuvGpkENYSiQDXTP*a%4!Ao9^|2UiUg#I1%Z9bmW zBQsNuP0g*bA!(y_I}5LD$y1IwwY6<>ZcrSFnTKwssT?$Jq!}3yp??jNH6x;6$WT63 z>U2}NG)e1ja4iAR+-hOcG`z^ns;w1h`?4C$4h^TaI|5q={*SA-jH5%S{?vj#}?vifl?(P(jmKKoiMmm*lq?@z$^M3Dl#yLOm$94>5uY28V&3V-n5u?x! zitjvOmkV+Ia|IVj3kQwb0=?8^Z-1>WmnbD?d+xJ~(fmb*hKfY!=;5U&@L)3Rr^M$+1buhV9K^jCo- za8>^x<3?%g{G+{AgtWUWVKm~LE!R19;olu59X6+2rb=D2bQ$1a4M=>Cuxa5O$HYW+Kr~ZvLQ%KU` z;S_xl`@MzyCm3`p-98UTVL4BZ2Sj%cBUD4<;|bR;@N}T)n>~rK7%t(PFyMXn!d}$1N=nx0zv%g4ng(FP3gmDW5Lq?o| znkv8|&dBU8@Oai^74R1*+rBzrq@eN}*S^6ce^4goj^^(==I>*ny7d)>QGY=B%BH;5`sN}A); zD$s%llz3f~xRnNg?QnVdSVy(HTh+nUu_VMbNDkPYt_q2>p+Xq$bgQaAs@s11RVLA%>}oujm_vHq&Q&Clm9CQw zaGI4@QOw;m+Ab$`8ng6#M!>~m9SRu06i7O0J!BRow%;;iGhdjE&~pcRS*$IsIunfV zK)?HuwxjY>lPvHjym}*v?|PXYId70QR9M_`<}-k`(4n`R97HR0hWl1X7xd8sjoQDL zJp4YYe}(uN!GGNbL=X4D%fLOaT-xdeD zOcS$+B<9Ksfag;h{2(G2qW4~n->X!uwc+&e=wv}3^H5jsKq%xI{^=K$+WZpA% zerdao+aUWtK3Jg-WE2!)I*m4EVjV`@c&J7y81U$hKY;( z0CXiby+6}>=DIl$S#SHPo{z+Z>$I9!BIOF4&&?>%$2Yi7>7eQpc)FZ9ndXc=K{%P( za0wcF4`ixpyxPR+7p0{&7=P`+)}+7L-kU1iUbg2IW}#QhvFC>l^7_5QwX~SFetERW zZeR*o+c%`=fdquzyc93ikP=5M=ZC*P6emLWi;rjsFIKPDR?`!xz@j2PGa~?Ug-kT+ z+>Cvt6DpRKX}+2}Hr<#Z4X_@*V#Hz_xqe^%f*PRpz`jL-hHSQcJ8oyCW4g(={gbQs zBb5t7gF1|=g9uvFsk4n!kieCiS4##V3v7lE)oZth-^UN5=nyfS|92!ZC1^OxpwA-G z@Jw4h|7DuAUuPILk-jt7!D`W2_B%rT&}JvU0x?_Qiqx=IId1mLw$am0HKy>mygWZS zq4?%8;kyd5XxE9ac^)zaIK28)`#BhS%G~aTd>9+7bQz2HIxpPxw0O;w&#Atims&~l zx47O1hFx#c1!V9mTp)EG!sT={SNlH!z>YqbVoV4wj{h9Onq>Ntk$H&0c-W|lF%RF~ zx_j*9hr;ED0L@Jm9I@c*8nA%lI$L62Y2h48$$Clz!bX7xQPr8R2l4enX=qIX#438h zWY7J9d|&wGF7`a4=ni7;>+m`zVLNBg1Piz`)_TC8J{eraUi-d_{P?GF@h_4WRiP!% zU+RpXJ_-@~O2f#7HXCy*TJh$N=QY z4J{7&gH?L1DnBzJt7-7S;nW843!QyBK0ReYOolr(aL;xBw1%dsKL)ZCR3V1f?Nxtb z7IN|&cH#A(Zq>uV=_x5C0@O7!z&&yz-8n4)l_QGbV1!{KOZg=p&Fp(v(2K|8X)eFp z_-A#*(dDyte~)nhwp{WCFzdIK?ieDfb$DJ`&DtBVf(3?Thx3}>w<{hPYsAp)LknQ(tn?ly}Y>=P{1>Nc;oL6XTJ@Caj40R-p>#@+b!f~xp3wH zp338`XppEhFm9U6%^s|%Hop+X6a{~cCjD9T0bzV1+w0Xr1}7Zj80>hCAOjlM-)1LT zpW2~qL9WFh0s?}ZmR8!)a??b{2?I+Z@CB9C4Gsl4Y#~yHqq!#1D5U^-oz}1P|Qe)z)}5+CfrF(Y%cZzGAqjq!uMlx%HR 
z7i?8MpGZ+mPDUK{;AQy_oQroD)YhClGOh!3P?-U5DN+liaf6lxc*H3|7^>$tU3SJ zv`iw_OhCwWVAPXSnWHw>xIey)8e9aC=$?{g*h5mKKFu$m6bd?sZ2deH$9m)Pkun1f zuBcPwC-=4^PLj!z=YqC4*Ok4g$=+QyG^|kc?J17GHs{poQjYKmQqV$6N=fi^XhuG-`!Jzp| zAbg|)7rX{}fZ$>eN^UF&7x^9+iWs=^OBmEWPhVUxPQQ+Wka;sHfphS)KQpDqm>uD&zMp|MbQ)0?o%E)|y z$#opbbh8tU+fj*vk4}woc?I8P$8Qun>bIy5%O%d*|C8heDOxU&D)gD^%%bgxt1uEk zO?;P29s@on6NW~VK6{j?LPwMa<{0&yHjW3(!SQhoUh+Tx{Iik=`SYMjS*=4}y<0VX&O=#IW-gQ#(NdPo5ZD z6G+=~J-v>(^AHpJdM+s|E1U53++rz}%KAUF6qCuGS?ls-^%KZlar^ua8r0z48x+gV zVS|1_XD&2~dsfHuN}x$QUi|9z`|47OWJ`U8wEE2_-&?N~#x+4bo!LF6>h-R-NT;{; zPT2qFmoy}TwsC*HR9od$5+I+hFSFgnYveesMwW}&1r${?C0OL96{kpI0e592>A~eG z06@DPl@HvM0S=TtD7z_~33xzNdG(va4*$JXLHCCfi0OR1yZETDt5vPvZZcD%K*|m8 zy1)g2xi`0GiqS6KzJl9pi*=99e4%s{Fbfs9amGVJFC|Bjl9H|c+}FP2oS5aQ-Oq>~ zuSzXD;`Sy{kZ&S!UKblZ=>L!afJc}hjR$VrI0OPQ)nMe9V>nOc1!7exaHeoD^n>95to#d{r7i`UHF&xT^*njr2!sH&C9!3cp-2% zI1t}M1zk;XHpaa^2xQl^R35V*Ifd~-WSvufB%gfGqm+>mA9R^wcEQ|1_U^>JZL^K= zO<~k2t6WG!&W=l<_hTsTxqY{hhxHjQ0xnP9OrVDsLTItw*rHAYaz?Q#7=k2+X^>xZwFAVzTL7%1q-TNf zGAJzN?I9*seT>S2OTCppBr08+&mtCxP9B{7s(xA^P?XGRa?QmfwXu7j8 zrq3&h9s$`XE3qOmpaPOJb8R&-E0zoKlvnfH^%CSC%Zob6ipn%~o=`|79!mM!CyP(q z=bVej?M*U);(5CBz&$~yGOF|_n}qf2Z*h?c(o1IS?Xp3;ew{nrKecLSsfY)kahubr zWCn_*e!CzE`_-juKE~{>2L`zyNcM|A4Tirzhx?f!L|gxgibk&n1i6(9sDYD4eWiVB z!}Jw7)|zeRA5wYB02ss#FIrB{j}U5381-cwkRTxMZD@E{HT2I#n_nh_=HwwT>S3vc z{`%$DF@MovR*nkyIl~LXU-Nj?ZM~B32LRt%Kp@geq481hDnu(wBzsyu$oc~KE56gh z{^)k=s2KdQgPym8c+0s-UdwH?#!KxQTkj~06R7E#B??#!QkxmLw`65As0@@+e_zTE z4UVQ99TC&%x1}X9s_4o-JjmSD6&H!|`XNKcCKG3v8wg<-7>lFGAcojBdU8wTrb zOvBo&OyhrpCr?1%Kkzm(C}{7T@?(|Fy3yJ5aEE|4(SIek;U8R-C&Y&juc>zby#V-j zYltQ{3FH&1=zERv@9yg8Q!$oMvS6K4P6_0S$jU|lWUTMV@#)@$8ioK_E!lZVUXct+zTQ*~FDZ0S=sH2dmqVPm$sVG8kun5o+fzci^R_+-8Y4oge-IjNG@ zRH)Nx>I^4erPdY8h5OM7|GVg^I{TcEok*FUzL=1bLpQHBL%r>pp(~@9%rG=h&HP>h zE5S~B^+0sG$W_8przzogh3@!{@3!Ed4*PQH>1eIH+nhqx?=XJY5vWNx5ZW9EI_!8# zb3(fNDso4nfatxG;8VH6=A6g%*45q+j?P&}{y!m~d!YCIX9WWZJ5{ znCuA(or+;zt!Z!=!+h6_tk9;Mt}7laj9p4aKdn9*No)YO1W#_Knvx(jAT>$BjxT%* z(z76R{B|}WTd&vdE`-VhL7U7=e0LY7Wi7z-5grqxAtfdz7Eqw9--Spj8pGxelMICa zNHOjl0%XIGB+Wfq{WkafmvJvU01DufBJx_s&f6FTJ7&JgCgiXn;x`$Qy>m0U;JOUkhxoz`Et+aGl1KJ5NK5^^4_=^ z-U@wS|G0*X%BTtz065`;K5_-B?oA*8j?!0>uNO>!;_@JO(-cLh5=K{ee2*O{kMtpg z0&^`?ddn6bbM{OGy}r4bCmDt5)ay34rXOU&C8rgI%Z9+Hff&-tvTpMc2(Su?B?Y8d z@~?Ucv)tl8g0SDhwgl(vq$!66G49UgFoS9X7tl`&H7X<2izl+ZIXU(FnbAM~%uH3N zY4dkyq!4W+ZKE}3c_*{knoM2cTB?p#Yrn^`V>&uGea`0TXgBX`z%Im8Ir9HWb zkDcdpVL?*snGj5B!o+YfRhYQ>R)Rk3bhzYxI=4J}bDJM^nF_hCm4*f5M_8gXobb5F zupMuhGvC&n_@cJXs^U!*{-~F^cDIkoV@#=^`j>JjQ4OkU;A+b<;8H7@Kf)Z!B%iw! 
zJKcF=sV{lC^6IzIor0n#HASe&W9m)zp3v|Cb7x7{4 zDybM*4r-pIx+Y1}|Lbj3PZqgwB5wfF=frBF4rdT0TgoaA%<4n-G7Nf$JR1GJiG4hqqG0M_It~Wm#8Q+e^0r!#`ThHf73?$ z^I4C2;#$8ds2B&qY`F0ULy@O)nz!eKBb`P?Gm9@Aj~OURUQL`$Eeqm#YeCY7IaA%z?YrbQn%MnjOsbI}liQRN^(3oU(4%emk%ygs?NR1xMCpU9UyMt zcfl2KqEXDUFh*lMu>*%}%QB9mlV3uj3(Rz_F~zN7&&wC;3q?A`lWcR&ru(zd^S*1E zBELk0?f8?63l||!b-=%jbbA_=f9+;`b z6z}%ZthBy5>C`~QyxsP@^0ycT7W%${U->+@ll6j@8v*Xaf6>tIkb0@4_TUP&5T*FH z0roW%pXF`(M;kd_+5-?_DsQ5?gLA|m9kf4?Q8TH~aI7}dap=XS+$N!5WqK=~@kYrA z(QvaRpj05|esE_Lt*F3TktG4BWYUH|;oO%lYSWskcGBP5Lv#0CnBkV5jZ5e(W2Tm{KRK^ZMhK_Q;O zJ?}yG3f~$*!nXU{_>0{O(fRQj-u8*c#Bfl|DSq@*DhcNK!82011e0Y|zNlxCt$x~q;d0>|UJA7ge$VYRd%;_3xt<`=6d(y5zmif_j+m)& zvm=%ZyQj7}w`i?&ixWA@raxwA9GAdRZnc73r$QvEZ+aSE6R=XM4uYsx^Szi>C_2&tU=Npx-(crxgaNGozz%?lxdC9~&EOzh8*ZYfsO%pC4W09eXNjkY+I% zFaILHSvmhoI2ycwt=;jjD2e%)MNelw@!a%j)i-Wve%j<=eV2-pNkQ`3j*64vFu~PM zpha3}!Dj6#>8Hu+^K@Mx)y5k*tSUWqv8l4#b|*g9SkIiSlsNmN*{_7MTGC#Ystu7G zWlIVN2u;^WLc~S8n9WX_lCZ4;1o*^SbA*TVhe3uwa9&NJn;z2Glr8=}-$mY*?mC~b=J$=|A6?2ILArnr(x zp#@Rc>*nnWcbH>H%gH5z>zhv0^5OdUK)Q4MSzr#NFH$qEx-Q$&=W4e&wO;WkaR=4) zxZ#TZr1)(Sjy3xFD11JNB7sJ~f#RA#CKl6-ib(-rmhnT0ky_+Rns7pQ7v$Q0N9^g- z+aCyiKRPZt^5~>IyNY=>*wbF;ZO;9#7J#W+z@GadUu~++W5SYP_FL^2JAH+?DDUBO z_;>B`d%?Ya5Ot#_R+u6m)5CamqR;O9$n(YFd2q=EgT5+ZNTZrm;UxO-OsUhIuDzV* zb+&Dyh!IlnJE_5zAU~?VC8clyNG~WJW7XnxmO91XvE=^U*<&ot(doM`IE2RyNsL7U zz}wRU@7fP!&-gPS_V)FaNw6_GQLZci^ReC(S|c5W4v%P1L!f3_)7S(s#=ZI*MrTiD zeO!xbHe1d z=JD-b@84FNEFB#kn_(wZGUcRhn2J)fS;v|bKrpV6{F^Q^?GNBrTYd}%Aql~E2J~eA zw2ZRhZ3ozyOT2FMKJ!af3uBWDY9f8|Fn`U*0SyvoUTv;pA?iv(xDmb8!;w>(j)d1u z^276czx5-FHMQ*I%?7R`{>PSE+UMlltx{Y!n{0^fs>Xv{J*|yV6(b3=rbxzlk)f~7 z_vV*@V*d%-uwH5OFMS7VOZApO4pd79v=7U&fpLxKq8Dpk6a%q%V?;x@&*hnq&hJD$ z-E63m89YwL<$;`M3(z?-olJ;PMBIuS&|gjd9?n%7qvPP|Cxh^bL7;V*J-NxdU%Pww z2N5SwZBY*zW3Zq(ziT;T(^~xkoW=SaORX;7&zB}q$?pJ`n~xU>l2^PvCiJXp2~kn* zQGrQmf?fq_0A!1r0Aor*a&qxoKM-n>ocHD7xJeF#HweazfTV{?feIi8hiJsveD1%* z5(QSfrBNij_JJAC!SQ&oUG@Lz^5*?u&)x@YV(#Db^RH6`ocE-#4-%qchc$>3^F?2Z zg4dI)fq^UWlht)1SHpd|6ll5qV>%lysdx*|1CEmu9e*a*e7Vy}!ptcTxGMe_y$pWV z4r)Q^`Evs#zHePjMDBp?`2DYI3vLC#(mMaa#krX;)mws4K{9a2rP`%^nbDOrgyhrO zOpmf~%chH*X6452zbR>Z1I!KCxtpgzgwW!{OvM;Oc}Xe2x{>Vjbp?1EEAc`IFMd{e@&$sFmx*@y}xtW9ZBBgKduI)SfIPF$k zR3o-cWS;teZWUJE9C}@BRT_-gM)$WgyV>M`Z}tfZA4TdgDrAwYXiVWxX)Z_gj>pg1 zwyO+B-4i7;wtrz+G0P0!Hm~=CsVnEtmS%dX1T{^vcKF1T3p|MKewJAoOgEt=-ALnIZkp6 z3IbWf)jHEitXL57RMbiULXyP`LH>;F&B^Mv2BQun1~Lj!#G=n0Ee2qz`|E+rmb;W6f$xh(}6zC}+98JHwyBkhF6rF`K&aC*)wkA1dFVhzi)L_E)LfsVER1u zvjl^PTmVI}IxSQEUg4fSQ~A7 z*lBNZo@n0{vWlq*?NZj3*7y288%&4R{1LVmS}K@LQ8riITZ_crNJQ==C z0v6AXq&t5SjmO&*HYUg}%2B)T$1GN!D!~i<*LTwzyiRawY<v zkH6eR!1MYlEto2Ra+)7Y#PjZ)db$MonmoxZz(<_M$?!?=_11;}lb-UQCNM+8M+ztc zj3zlNtFnN#C-zr;@3-DST9^jHkx4^K`%NTak;Tu?Eo)6tSnYHEP#n%9EY<%zM{%9m z9@BjjL=%re{IteADvj)i6?J?=C~8XX@vy2f|BYiq2Q3ck%VCShf~%cfRE(U-cWSm6 zMB;53gW-li)pk0x+cbE@zof`|EKt+^Rr;^s(c#Xw zXMD@+s_PGBNy<3#I5vdP>TQ>CUqBoNRKgyu=3ig5+ogL)N)CKVPtLFUb#bH6QbncK zARGydkYE%7V^17j#8oJj3D02AfFPKUP?;i)ks7!dVC4UP1ODsQ*6HKrCX`lLurD8yZ!2L{X2b(=dj;T$x)`!YtHO5+xVvW z+=9B_jiA7a{BBfF6R(_OgBVS~+mnWn2f4MDy*06pmoWeGdsHE@Wn1jmUfw13N_sRN z4OlnIs!fMI_@kD3Vt$rql2x7`P}V!AR8Ym(;gwQ6;~B3WZ{ALpXMie9`-<`Hj>VM~;_vt+Gy%*KwaB(qHv()3f3Em_Uc&v^W9Fn-TS*bMNO zC$>IR{ILc-@0K*LWV>1N>#8fPph`aFt&?YJ_4u{L4AEhBaoDAN|jyB)N z_CKr#zqaq6?gC&D#7RJUcL`E?t2@vJ2TH=lI_-?^&++E-X%e1Epb5C2kr68E)LUR8 zvZ??`=uLL^*2$90pFe+KME>gVd0tuoG3#$U`(!D-A29I(2d$BRhCQl-S+!?Gm2^wx_o4eeI1fF0;exGbNi7R23wuA?nK&r5;A_+Yl_>2-Q*2vC zZlPwynm{9a(Z6B84j+OBn?MDvMNcKl*sLtA#nV5=>wk?VN2$iQ{)5HBfJxU4KQPIs zcs*1kTHbOa86FrKf?Aw+3jo-W#!ShR?7C6{-Qe5o`V_Z;2~8z?veNPo_=^gG5qQuX 
z#jSGLEb2hZvXh&`;r`MRq$VMNDxtqGGDL(G8TKvJs@~)6ztw^$V0w#KQlL{S=>?i9 zB6=DF5ebRWm!fdla*;ayRGk;{H7(N>NWg^%Vgz&I@|%1Ba?cvZq;fx{6Oo~z?(gNf9$+v!>#5dMuRwgVI;5FPylr_OxB z&c)Mit&Jec21J)Kf!RmKSMJsbw3#CrA>OACV8nx{F<^a#fV~?iKTyLGvB3Td1oLbH zoum8uS#+DpcD#WFyU#b!%Ve(6Vy%rAFleDb+R|SpTuHOcchhoe#^D7@(FWH)HJiWh z|LG);>p-3&h>0o$>N$!BT2ULie)OLsIjeEFUX!=Tt9Jq(jU}^t4HkxAZV&}~_;mud z*&8hBzcfhuu)0(g-y{uK=35BZ&BAUXuzajU0;KQTSX`qcF3G+r5h-qy%Kv2q#qkH= zgmmEGw|3n9iVxDLQTqYrzR2NAF)p5CiNSugnDqz<%)O@w@n1kE08>0i!J^-WA)$D> z{_C~VS3WT@d!gRmJA{9VGfvX!jJ1>*mwQ~@Zw31*k{=|vRIj7f(3HWv9jq~Z9Q=!$ z8!2%AFC(W8TcE{sVNG!CK7yjL_Xi}c=luk|Mg{GynKe+AO8}$S?o8DgC-%~u zw*>g{8uUAr<}|@tZ5@Wy57q5G!2I45kcAx{cdW>MlbZSKap6YG+eF|*Tvwt1PDVmQ zHMFirA~Eze?5{_$RwBc}IBu`_>)D?QBjmw#)nrHCKR&##`Jinw5+}gi)JHkq{$)QQ zPps}8=n>w;TrP;85J5NyP@_!fBx?N9U+B9=oX;PhTCz7 zIB1#C1P2F)J5ITc-rOL(_1>s+dQ>0R~f#OSxPNY?#__@^NI<0Bo6AWx(#-K)~Q0u|{BN<1_Ls`RUivN!fV%b8) zzLn0*a<4;DDaK6&aC;6#qV7f9g7qak&98ZV}t4?Ugvlp<*Vg4d!0gi zHrD8>FDF(j#bjevjy}hCV;f46x9Z#qm<+~2gMfA{%*d22Fyo*boJW20Ogi-#k4+%u z5y$UzsX-1(Ru-Itl>xIfC5DT&rw+$m9#dZ|U*Ij2n-S=D?KbzS)V*Y3m0e>tkF%N; zzt^0}X*qp?(a65l=9sNLLab-uC+<%g5q^YfV^b>pnb~@TKV2+SdBjjiw83zV!?K7k zzSN!6poxWvNuG5GT59Aq zKi^1@g2EPj0!tSHg!e%JOh zK_gI54CJ{HP(MM?sOICC`4o;wf{yT7TYQX;a0;^?JrvdWEcgdEfD5%|kaRQ3!@@pu z)XGrfExQ^o%N5!GJo?K#22S{SYkzH#IyRbY3wU*x~DEaz&wr^Zo|)LAYPs*AeaWr0h8u$_FHW#!vE^ z4{N%T2LDWr6!tK_8E183Q09yxQRV?~?zK@H&^t*I|E0y)@9kUGYP*72;#TGKmv<4i z9}}%Fi${=f{?_D5BzW#6X)0|6&uj;R=rElfhm+`-f00h}gBt>G-S>eB45Bb;^b=Ad zrrCW4(%u3~7Q}O3kO)R07WWD_nHG&l7S<~;HFC7jSQBmh7r6EEB>p{@YL;dewkjXf zqNUCH*EtiL@A5cByl1eyXa0a22A9zBglkwM7&sH1u9fU#s(rC$XE4OT4Ie*xtW7HH z%Xq}fw+zArnQVyV{{Ht=W4syMr9Ks%P^>C~#=kJI;i#PrUfOPD7Y1)8f$Gp;b@Sr^64hCUxX_^%3Jcnv0=ui4~?_2V5?7 z4-ui%_Ux?}YK(L<)pz-|z`N>HBc`{yar+7y5YSq===K~DVt4VSmsvP z#dT!9@~22MQyu@Am1OePR6|=1=OaF>%8LVj@Vi z@Z5e<#oOjXt2x2!qMQ>dppr{MW{S9Z1BK0=|Q<2q*9de;>V=!dh2v6~^8BEEGw4pkA8+E9f zv83BihGiKAF=UPx7UrZcqI%EBB9z!LoB!$({Qd-7tdYD)oYopDNdmz9?rP8p2DeFM zvmle0kJ_=}oLIKF8M(Yh#=oOeKKueLHHM4CA%Vin&Pi5Iys z0@BbJLVjc=CBjJp0WG500Gn5w4b|iA-JPYSPQ|CW|ITs_y(qpxa}HJzF{DzhI$a4+ zf5}^&Y;8g#uqp7tD54L0pM7>DTh7FA!nVXr`CAzzof1;6hu+wg``fMF?!P-A#T<-R z|L8Q8q>aEqk91$7KM4%3W9X6S1bhSOibVMM_?u~94wru+pmS` zimoL~ELt^C-usXte7I%1pn|4C{alGDyLHH@S(Sw!OWk<@E~Tm8F5Zh%JZ^Ad*kAV1 zpBFu$1JD|X)?+$$J?-tgmOfArLDr`I_3X|+OdE0>V(!?eip`KH6*mqpFcm)M*z<1`m>C_3|{0$FQWSj8!D7+Bo!4vaevqro~{9+-LmYC*9R zL>5hsqI57ic9^uMx8BKV`yyrH!SFDAPNG$lvbg5aKHZU^X^STQc_XT*j;r_wdqWE# zbsr3twEKVk=s+d8C|ZqBNE+L&X!^R_D7X8*3)tk~reu^dSR04#Uzc>B5)coy^OMZ? zMG#o@@P1P(V*Q=p(74W;b0d&bK{}dO_fBWG11tU$pq&BlC_hu;0WTCrFazM+)uM7Ii6hCaYCY@FFoZMZ7hyT>lr^3{4#OA4M0N zyH%7&b9&X^qkF@s#<|A}!7+x^qdrGO;V(`$q%NOJE~K>hgXqAb7EyrvYmv24kUmf% z_P#`O2d8{tEX1{(yjWzT8Ff$S?`YAn7?uCKGoM3frjw)Q5_My3cr`?)<@yDyN|1+! zImIPbCuuK~?HT?6G(WK8lVHoN{Vu0GCN^JTU)%})6w2!E@o~e{yNIt5whvYoaXGo* zh~Y5fM&n`lb~swJ2v`>1-7~;Q)@SbB7J^W~!8r2wNoH2gG8o|qQ}NQf2-dWueEV8N*1YF+?IGe_$R_yG!R@+0SVbtuSRN!d11B&q2tOT8cg@&cHTw4 z=AIdfj+`cp4d{oO86i+nX>)fR)Pd0%km*UZ{V_?Cy*sMKhI7uFX!ezW=x9P0T3P!? zu!?L^a?QJ=RBfecp1hvbg)lC+j>0SkBe%E4e;)())>;c1jP4(AW}X+L14u^5j5CV# z3sEGe9CHZqZa4qR(DMz0obp0TTqZq^A0Z3qWmQXzgFR;jty=niFUD7lBHNHdCSf$_@wNM5)ES0 z-R95zWBjZz7wA#q)63Z^xBI#H?_m&s5(R``yrz4Q?Lj+}1pA~8?30DxT{;|JcH%v{ z5~MN0p+#00@ZT~N&Wugjr#F7HVMcp1BZW>LObq*0=JkYNMa9jV)DfCb@Lhyz2bN~X z7#keL%_H%Y2w&2p&OBXa&{B$I^I$h*rI<^y@ z`|K_(EHLLr9Eg59JCUjHCqZF(>G;-$u&+Cq@e5Ls4D98>tF2+v=Pz>{_dv%dzV#}& zY-{wdOg3MG!c(6sE7~*=O7?F!nDt_}I9CR!u_bPT^zy+}mfYF*?+HU#FqnS}1Pw#}c>@^P|R)n8K=lk%PoqyZ}E(O6QX zA1AV#D>_E+&=NE12`p1m%YlyaHm+J%a^4C$zL~&;Lue*~NdlA5T(QfNGkA*aB@uQ}? 
z5JM7NtEHn2L*~dq{bf3=^ozXLdi*%!T$+(`$@{S0R)#N5>_px5{w*t?5+fM6zHzc& z?QIx0V$toj#2SSJa&8RXnpvek28=tGI7saZ?l+JySv^fwsaPK(z@qFIQHOQ8 z6;;d@wNAv8hSZB1}9BTq4syoeLCj`~JI; zA(1V(g$af@287<)@3Y^Nk&!J0OoY)prz^HeO7UuT0nJZ4Tg|NT$8ciq!F0P;73zK zGdrhCvmwuCyy;)MDn0f20P99k^*2@zw&wWJ+p(+R{8uDbG5q&dsSd{n7oTJR6=Y>^ zmJ_vK9SO*Naqb)X!AE#&71i& z+y~&quZ0m4Xas#)LCW_ekj9QY(1bh1ce9Aw_3p%A^gR`;aYRFc4EazVuT-RtaJ@B0 z*E$sCSq!Oq@)txVWrtlji_G4)Z=BFk+H4iQ=vQRC?IoLH8E&0<@Yrk# zkRZpIT)Bx3AEGoK%#Ddokqw;Urp5OTl#+HDeBXRl8QuSF#0hmD6`YLrOo*% zOSCqFW1KqV8IjN&&MB=chbUrqOJX}Q3OP(JaZzG)#M(%pDT0H7H*b<@rpjb^rAkL> zmeSg-gj>0pB8bybB@6_$)2Y&Jjr=nJr-xtPg&|2=N zWt)InEt*YoP3Ny;if2mO$YUnaP=&T+L~8&4Y5{)qH^18Q*m)NaPbvRKqkzLu1~&8q zHzL&jbhfcH#`xmTKLYJ&^y2ee&Bs3@WA1HAqW} zL(%3)EJh-IlnZSywRP1t&CvDLo*Ycbek?WB@+_ZF!c2&%de%+zvMfJgiL+@6tD1Rg zjfksqrGoQ)=)*Jbxo2^K|7g{dD1A9MC+Q{ z9h5T8idNhwXxitUiNR5C;ktNefYT%v1c1B+tr+e$p4Lf#5W^>lIc8%t<7dU1=m)n- zMW;`oJoL=5nGNUDx9rq{)pbit?tM-7u>~Gex=zDQH4<0ma5%#p1O_1llGh8l=`3js zEPMlW(c(AmrmVgVgm?T&bg-94mL7;j)dl&|-|GMESl*ufxjpz^-I?SBBUh$c95Lg< zN*IW4LD*4oBPPX3*O;~>cQ%~FJ=e^ROTX1!#Yr#Ayw>0M#j+xyfTa=JLyAXp?i5XO z@7O%TsxM@*JIUkaGF?-bXT#WDltLY&bDG4^* zU`|RbaiEt>E?u``gp45(wJ}jNTzxH9V8d*mp?7UC0{SuuE+T8BH~wFWYSvgtJ}2aN zV$DG96U!(_N%8D*{h(kBAAQd6$%M}77aQQ~V&y9*7~KhLndnrJ3}qLI$xEw^6N_!g zAdpbh5tJHkZSQI>zux64`~BsK2YA{B6^$iJcrhkR`dyrAD)PGxP2u$t2tqGQy?yCc zek9I*_>ON5HLY^8yNl;TAD{zo!j=|bm`q*Uce-7*EAw6Gx1%g;EJ*YeEo;>SxXyX3L|nhs)E~X;iqq7E*RKX}qE}(zs|_ z-BJ$WubUdAQE{oacB1I6b_yOXY~peg$_}G+$eWt6?dGH%rHoT8?@Ait-K)goEU`a* z&lV3%e{91hWx2C?3m}_krJiL~vk=NGhD1CP$$0NY%s8g(ormp11pQV|1e%H9f{Eo~ z0Uewuo~^%lM0zA$%7mf)+D>g2Gv?B`+a_Of9fQd6G?0Y);3;J8apvGQL*_CDYcKB5 zkdY&p>*UB%`kd0s2FC)iilNoe1VH-s=}&mWr!vdG-$R1E3gm)orE_SbRa2(@nL#a( zHQ&^Q=^FFH{@|_LbI)|bi^k_733KxAiho|4f-uk)CrS(O*~cr@X>UxOJ^Hg8 z`)EzgprEB-Vq;L*#FMSD5ySD4uk+;5AVJB^^R8ubC=_q46SIieX+*NC*Ysl&8MA9{ffE zDK9usSah^w;g!BDZh}fVzNfRdSB21XC-06~VbkgATlP*rD+jc9Zy=b#&Ac9Uc?X6J zTyqtVJNpC1bqk7`?kup(j$-87e;sGcAEB`JyWpv zZkrJ3Bl`(?NE$y)&c9k+2ru$2JJS!gpQ;JFYpLb{3IRoMq73BuK<=21*ip%YM{837 z87p0WR#F%WuD?A^-{~QnMdZQi{_kqq)}QO-7CA2XczM``z1tKqDISvTx+=^Udn;#x zKI;`OJ0^mNZsrz4lBm6s4!R?eKc)jP1lSvzA7dl z11S3u4xfiwC9_V8CYiy*T>Nsf{>#r`)TL3X5q;Wez12m|#CWoH-1^UHOvc0S@CxlN zL3C}gw95CuQJG%_NQ(3;oZgSAJ~Z}y2(}DmS3U3&gw<~^k^Tuvhait5C{+hi&5DqP zcO{mfTkF%Tzrgu^egkjK_A{&)ON%XlLXP%qaqH0xo^c%kwPYu2Hv=GoxIf004sM+0 zFHCwkR72^}ET;W0XNCjQMSJBbyn+q6bB*p}_oX!c3ZX5aQN$~TL&2gs@QGR`O)8TI zW7)H!;Yw!5b%)bZO?hh~4ty{SM1d=GsX*GF*Gih5ntDHX-?054`~n7N!40P(w7UG* zmVbyX{x^iS4EzrgqViG^UrJ`6#c7&2<6ps#6>GD9inz&;Z20$?Wmy1BTqx!5d6hF2 z+9-Q=$LMK~{M7KtXn=U`P>CFI^ zwP2kbP914j2TO*;T>gX6J!*VmdRa48uhFNhRLJPJFxny&*?i-m{_-+u*|3GZPx2Fg z_?v<3{yp;H^!pFW@+?{Nqj6NnC8Ks6<5pj;+M8US@-9*`L$K)te*HgGon=^7;kvCU z>F(|ZmF{kkMx-00l#uT320@VS?hrvrO6l$n>4uMn^Rm|7=UnHHe-wfFy>red#<&NB zz@=^d@&rcA^YTz6(xQ@Tp91gC`>yF9T|4(0kYw8;FwD@;x11jw^{eh_5gg7FRC>t6 zJZE@7I^mvtEB$XTssD6$T=J_&d7^EyMd6QEwitO%-4f)sMmjA*b!ebt%YsjXVCU-q z;Vr5FJ?-o2PlJ~pDUI54+(W3?GF1w>20d+H?rHtnW?}Nv8<|19GU;~)j&aI7$S^on z_$p16a!CDam2->>n|Y&PFpJWM@sV?kXGeG|tNT5ncVCWz6Qt;8V>DD37=cM;nwptF zuCtk)^GQ3{b5hz z)2H|+Sui{Qt?c1?6B#6833ynXT@BMQF(9~-{-E~Kn7RZ@h^08C9fZPp-TK|t zzO3s``I~5)?%L*J%d;6&eqO_M|Ap#qjBfvEyMg3uXgj;Mw~7-Hpr$b zcCtFt%a0Y+OIdzWS==XG{J3{jbj&qem zV2DQ@mQB9B&6Hh^d^KL8qo>fVe>RTznPXQ@D=IpJSxdS|s+jpkchXSdSg%93*W6>J1e#*}5uUg-g z_Bk5mtRPJEiasW;VWHwnMOICV)GOpEldg4{fs|0j-g8-vnj15)m-&v&$mMOUj zhl#B~00%0$T_sWEcfm3D957BMic8E3!VTc6qnk#w9di`Eq>b>zia<&iFy86)i>J4y zcdleZZ$xoED|mPMy%xNEZ0!nF4F9Cm*Pnf;Dyp4PPH%AF;aTOJTK_{hO>h-;Ii2a# 
zTf5xhM=)%Y66*@Tk0vGHaE^>k7cOzIKhS>uH*9!Fnj}aL)CLhEfs`i&(NQx*KolQMb<`Q%=8EXC{)#lu@NVzp;~eBf*ZE}by}n6d zS)*#RyY=B*3s$iTx%4?MllU6Ff#+Z}epLpkz9{<9Vt`|iAd;sGhbf~PG265lKr_5r z{94$uc|1l4O9`AX&>TG41GEra(Q3iVonsLTAMw~ABqkt|X!MXXzXSvjPm4aHe>}Ve z%Hc^bAELRZXz9jF?Y(j z>~er1XTnI>l8l2Dp=T3g5z2C@IM~HU|4YQ`s_>9u7IkX)T1EIXKdTi#DJ-&5?vem&jQR7Y0>rBtR=d@OET zfZUFx9TneYQC;;E4Phm;8D_*~XVGtgriLHQ&vbOCpNX`>ix)l&tyt2d?cov8mg6ls zgebSwVg$;W2Z3G@&zUUOCB8=0iP-$#QA3e>|2qe4VE_Eq~oS~n#pY>rXXG}-IR zJT8t7C(5W~{Lm*?QwW%Uugr!+D{oA+b4nh&6~#dzRYV-`k5(@yw|HVlg(LT(p?O6i z^mp{Y&Zni^4`vp~IdPVqVk|bo%(uq{=sB`#V@a$Tz|G`SSzNkFS^}LeRsc|!vc;{~ zwAA0Nn7N>Kt<1Q1Zhds%5?!ndUZgjQ!{(V}o8P5BrH9gU2Bs=WU8etH-(e%8hsI< z6~W|CHGid*PvnI^I=+(psKgHYEnS}}x*-{t%9$trPU-~v>=2s%VG;WNZMDn_ss&rk zh>-caO3?-7wu5oCP4KqQ2euH$Z}5ig-T`s=_`j#`Efc2R7E0gS4E@TkWqB~$iq-XZ zZ6x09Yftpl`?aQmx;KgTdcpN~T?d*oi6?A=-4c3#tmF=Rb5dRP68GxWOf_hqH!YUY zCz|VBYaActE%e$vcnehmq&sy`^&_0PPd9o|u(rVRNv9UW_@CA%zTo4& zzkUddeQiT6TO~ACbMPKj;rO}fQ)gm`U#>z-ixpW(9IIay#l<)Ku2D%n6QC*jy;>w? z-T59CjDhhBMdg0VD=58wp-Wd+(jU_upW?e@A`ZBXal*Mk8l;KbS=TC% z>5=4bU$bUOTOt}^f;AlnCgBw=iZ};yB)Y`PLuaShnef587r^C)8!>v{s{=Wbsd`J5 z^6w__xIg8Im~i!Io4JtFTHeu4X=@Y4Ni*uK`E~`)VugPYbUTIa<8obYg|%JVbuO4hPYnL53E^X<3{LXpXt@zdo);A z+~M*)8M7?hMu>#3m1;uXI{+XjL4q_Y<=TC=j1^GRuMoDh&SDUcTaILO0ci=GN?Yw8 zN;rI(7hTU)g=?^ATaenpcbeThbTG>hhmuRv_&0Nv+UDZ`p3D%zLz3C+leCW0$CTMv z4kyTT+iAbM{RgNoZK(lolx(O&ln+@}a^De*@GuVTJW_T0^b9~oxd0$9%o zG{w8a9|l%>)dg>6;{5)N_KdYJkNfS;-qeh$3VA_rXC)-#`E6an`JcLttgl6=qL?K3 z+-GEj5Gl~GH+k@u1k88U!}`@=CyQ13f!$;LWf3}-2DC{pNR+PthpAK`z0Ij0z&|6* zsdds0)*jTY!>O9VomiZ*(b=|p2Ns`6VUTrgsND9alAVV8{0#5G7#iA$ZVroGV=W>n{b&bz7s@EqNL_&Cz7ZKMjA z0D}Wg15GFR7f45R2Le&&9)4&V>Or{RKg&a_)?zg!>5bo_qBrO&F*Dgyq$YY(H?L{i z$oegjXzP_4ArzqYvzanpXTM)mv+S#aMxy?kUr?W~838Sq^2a|Ef{<0IW~WIS2M=Xin zguSLjSB39}z$H92)BS)Z=YrE6yXzlsfTJIWWiWO=oMT_V$&V)Dga_SzgrmtoQ=~^W zoisaHBEU#NGUV6RX$muybggEHx$%+c*~heG$`7plL>w=7Zmz;=3As&}pA0(Gy$hx) z774rCzlZw^-MTorHTyo)8I!mUlc7%V0F7rnx`$2RC`npjU|^us9)^0{YyD=`uj~n? zrOMc|N0AU>8jO*?I=l&a_J3-EtEO z<1@qfsx%Rq-^POLgPTe2d=?^Q!f1y_Z9AdBHs;f8d*rZwCh;4pk+5AdsSVDD$zRX$ zM|*bSU|cNdN9?!tY7Jx4FbT9B;p|JlJT%Pdd+Dq6 zMVm#oNA2g}zL)cMewT9=N=2!2NK>zLcD`3*^n5me!N^__Kl{Wj4X^&aKxMg?Rj#)}0S(Tm4Rw&*i?Jp4OS*BB&&r6D?8Z^=-Ljt@F)4mjmqlH5f3*?u z2~}j&twTjWoQLKJr&US7K5)yB>KoE=aqR!;`@m%{h+cgXrqb^cVD!*)`YxUj35OP2 z^!CStu!(o~^UBR%vu(#S94f5ffZ@(&RL9)W*8(U`9pcS5I0YB%+8 zrB0}iF3mdmjW&iU`!Q8mc1yv2e4Ajy&6B_#vB$eRTbp!8PknJ#S8_0E-^~Kt?Y1FKI`yv++5dd<(VKV4GKG zvm?~r;h#`*Nq$BxjOo(H1e=HvW#8lQY%K@D-={wwgxtRjmYNxF%0H@8O}1Iw6y%F= zT3@15tA&XUEyrfx&3>6{qc)%;Aojh5qfSDo3*AR*u@M?_xD*r1O(B-!vd z*ey=-EkqN$%wexsnrL#fv#Z2uEsQ*_ooq&@!Su@+UV*9= z0U8Tlr!EC9gzFrbv2`kzZ^}MP-^Yl(dc=x${&21f3pXc6M<? z(?G;0%X@$Ev2k$i@^Ah|UGdNjLgPNScwjJ+;|@g0*~$=>U^H7i?2Rz{`C@ivCGqf) zAFn+Y&+{VmtM4`U9wSoJ$81`aTTk-XYkZ-TN%S^GDuJ@#$ebi~^2=KGRPs5!Abn zgjz%CAgOq7_c7W#kIh&-XVdU<}ScOQVzB0v!*H&zLeCt-3@zYm+UNPDdf_@Cw) ztdj;$04}=oJ2z;X#EO;QoUG81CGKx+9CbHi`N~5{!pAfI+nINg9Hy|P&a=R#70=TW zMe`$h>d%faeDo&Myx70LiEzB?(ji6Fv2BXFIm6B|%(V*-b4L`w3|~3A!eO!%GHwf! 
zVt={U8)fRU$LNtmFtdQh;bjeDcso6_J$*YTr^hf=*Rgwq>}C7p?~#H(yFT1U2oJk zEA{-~8=vie-E#44&|h|5mqbOAB}8zYh%(+Q2tTO9&e%|&@l9a=_3VWEW{T2qF1;A3 z!{ux@K_M;*^$6Wg4MKn@D{*%cnAm8aeIK+J@48TBk1M17bnmj6BwuCr3<+v3H9~k7F0snm>vdC>Hl8tP{2~Q+!V`Ho12X zkNjxwQM7(t9PW8`$UQh%zWTYX<~@v+Cyx2KiU#77HFBWn-l*D!Ct8iGaOpop>?9F8 zP*A_@giA>5}}S z>z8&>8t8inhCp?r4@-D*EW5=8P1_Yi&hI|5ddLY0{rx9@+jK zaT8E=Uu}UHajaL;t6nF^pW?&uGYphBhPu$vVcges@b93xs2HqzI%~B@M zC++QcZN|31Mw7GWi6f!awuxi!&~VflyG;z3qQTMwpfui=Rp9b@1>!6bm4 z76ZAvI;I+{#S`1-tSM*j^%uJdj879S$9;cBS04c(O~8#r_-o=4uonKkMoBCW%fMqP zFP!qfS^z*&VFQOY3}6>AofinBSo``eU5aX3)#aWCw;14sU%|K3S@vJzd7d;lm=g+e z0fz|pBA-T>TercssC8uF8p0o*iuS@}cJ2T3vto^{3 zPD;A2Csyq%tlqd_I4=1aZWJ$hR|z>hLQ{E&GP;vLj_qRnNa8D+TjhG%f)P%HH)nnd zhqIkm2ebJaUq>_eaSSZ~G)uQn5+U6IOZz%d{ErRab^T%e_;)0oS7i$N>u<(#;53R7 zjj_F-ucqtI@ZjALXc^-9HOahf7bEWSk4D^EGS{8%ZomG5`Hfy39?A{K!+&-WA%$%h-HG0Rg# zTIIv1-Q81Jm0I^Cw>hmMXdzEeh>$g_u2=!!Yvx*O!202R-U63AtxZbnWcbQdk3BwF znH3(c(QbJm$*!XpG=__WtkhOLVvJ6|eN=!dz!XZN4=SO)YZYvidQUsp-bFkS-`Oda z#)#a1CRmtZDUEu(-EyQky>QD{tuoadGH9U-wXu9Y<7_9ML9iIm>U-E;ToNhHk`eCF zI)FK)-6v1HE9^NT3re7_wS}cYq;7*xGYooXuGIYH1Xto1etn#C@ZRYz0W~KPtNC)@ z?FY+TH$o?++(%Tb1``L7`e(T~-D*Q&M~kKOq(}L3LQV~KzOK=eEAp>MaHzj%***jc}I(rCX@FZw5p);WnsOmIZEH1AYtmy zT^kw3B+dSi4U$P=ROcuj3KH6Thdjo3U7nq${;$q8Jp&?~Z!W5!q}0%@Sm2R*&E;X!~oysDj_M${>T6 zNu6*7#jGzHdRC&!V=3oAT3|?#g~zn;@xaV{)5Vi0g*faL8hGC-q3Io#voc*!x`k>% zG+Pq?Wai>>6rN(<+qg(BwOxuVOxxtVyH^Nv96R>FQzx{rti;(^TpXA_ z^7-)~#*`6}9+WqcOFsol_nagojE=SIxR46|qIeAT4vD)%uRmxPI4`&=?&xdLsIHHL z8T1ZN-Zpic5UBI3{LGGcJkUl7TXMO-liVuae8mWP^Z*@|8mp}m?&t?tU{HMj)m;}4 z-G9CmNb>Ri!dz|=bB5HUU@`69C?-M=Bp$AI7^#{jT0~CecrzqK#!Bc80W+R-_J9D- zf7kKv@>6jp`Uj)w=$YjF{0pD1c6S=r@k=C)to<(nFM@urf~khF8t?VVDnTmr4sj5@ zB9Yx!bztI87o*MlD^+9?p81#{&k-v7F6zNcK(0s9>av{1aC8?LP4f zl4v~LBPvBcVN||Dml)^?4H?dehuJ@CRJOYNAhfyKO4%_t>_C`Rv z&V}*@%bbG&`#m+=?`x|3+pnp!P-s(N3K6NODns}k2S)bxg2d*AjiwRAdu_P6RZ|S% zQ;Bl*k8;O`NLaH@-in#{h!|#*o8Roi^Eq3z{i%8=I}_MaC>H*rh_QLy$POePv&T0F z2h*lT+gJg7)j^JgUpN0kri8*A_u+`5|0_Mtn81;gMAsvnr7o|BtLER#@K zbB8Pw8I$eQc(IAR+{UQXX%kaK2*9E)F&54fjccJ-i(lK7yJ!8C&-ptx-$s}(p+@d= z4N3lX-bxSI1r1Dz$AvY>@Qmo~>(Yq;&e$2IO9o~cf)2`?4U!s0^-9FWpMg@w= z@aeo2xKQdT6;IkfE@%_!xwCI~%C3p(9nZUi>NWh)14ijH!Hii!~{ED5Uxzk+J`Nw9s#GW)^$((6%dn^BT zj}aPVSV6)y6M1v^%-4-gUk5geUgq~Nj-Ep*VjX8FK+yOAB~kP{FS8Dh&zC^^lC|==~Tp6L;8Ln(m+h2b<-hqOuMG<26=_UxOw^Jt!m`f-$iq0v_G5= zHd{~i)sJR_98y@qu@Tk+RH9FxI_1M;Mf@TksmdK0&0og`KcN&6lDtcK@1gp}5(oE> zwjm$Sm%5;N9KPIu6WlBI^2d*m_r z73u3-7^l(;5hZ=!?kPoEYh2a%vUyd43D!6?4k^p5bX7zmec$r6*?8`FYU2qfdo*<8 zPL51_>@j)b{%R!Xf3^iRq!zRgBxg3PcIoxK99DJ?Sw04fwoW5nx!4%}CCSsm+!Gta zqdhmhS6I-VIb>PTt}>4uyU>eNTS(!MZ|R2o*`cp9 z!w?ssbixnZ@k;#ESBj@mGU>+v7~-{&yO!1k|7=;WUcyh z?m$4}j-v61#{9#;N}Ja!G6pgVZ;J28XcPj1&9zaL?deBt<(viZ`fa{!ud|&iqXCyb zvm)m}kj?!km2NR$ChY3_Mt`Rz&w&Q5N~cm}H>UhwiPrtyX(ej59i{s+`hkB|jDb$x zUG4iqEG&7y2ixkZk$c4J+dY_%$`d^%auUBSYDI?^`AF~USs4Aathx{UEDaSDP+s98 zFI&7441tEG@zu7JGu+L5KDs7YQ|6ji31 zk@WkE-#TT8Y!OXb5uX}27YYTy${6hLBcaq`m68WpkKZ86Zj&@*k`#A#yGwaLnuM$|^ zDTFobFvYOpWhi6ZQmGCT+pEiY^l$iW8IA%IL$Q7eadDVc&f=1O^X+M!yQSa!smm}d z22ne4(S-l3Rj0AhFdzy%79XHp(A!v802J_3u z50K=eG4Bs|vKbp{B{wrCCJQi=QB21CpwTAqiLFO#kWa2(~`*BJAe(l~Wb9wo#NWYZPpA6p#uXzHV{lHTG}4S<>9KIWvE}1^OC!hq#KzA){l(x zRc*cJZGyE^lN@Q397nTR^sUomv^h_^H%*z@2i>93O6X4r z7;IRmI&vPSQxY7)ucbE{ZraS)v}yioZLLqX^0}B1K;Yl@udLxU?oZg<`b4%Zh=r za0whAxuzr(jst148PwvW-XG35T{s->|J!A;yt5mA$W-N#H7ig`=~W^g*ZKim54i`2 zpYmMQZ_U$HX^mXr%-BJXC~f1OGCQh53bGpdSaA9%VF4LKD>y5{M>bp(IUZUaQZh zFu&;F@Nl5L&Z)`(xh*JZ2VCPa5wrV9bcD2Ase6S$`DA9zI*o&UF8gW!f;uxXf=tz8 zO!YIlweVLvh4D{!nQkjs?_J88_S4-yp67Mj_*I$ww&Qz5o5|vs5>oj04U1_rxi`kr 
zR@xv343}RRA2%ZebqTDc2ae{=P)QI>)s_a*(JT_kj*j}abm3B^8dDA)vc-pku+Y6C zbifFSgoj5beNl;10&a+LOvOZ@*~i?8Vi^TN4wzbB?03_Qj#)1vRiWj+1rF;&Im_9m z*Q(W7P6&>-$A7I(cV%u~jq8A6#|y{_nmid|zqSPJZhTN-G-hjJVy>J<^7+E9Z+dZ5 zwNa5A;}`Zds>gj7%Be~=;?>~Jmr*F894&nD6yu3dG?F~oAx6ox+^v(76KJix&|N#U z!;W9bMwOqLL&XJzy-1m_=>X5^`63q<_UYRfgmeRu_<;+O%GHY=T z25eW%H7x6cH-=ml6cvs=>}MJs?Yg3-tyBjPxbmy+foLD%D^+mN2u0cV@d*QW+zvK8 z@`SZra!CMjA?T^m(#&AY{^tIauuJhmwtrA=}40uRtwC)rcWsbp`YC zkjQF|uM-Sf2VpD6>BMnvwrhZ_kXOYwD6L&)^oRT5$Et7l>3TQ2X%FJbhw$`jAd3$=!btQiO zQGH&EXE})wqB&i8E1&T3{?P{gM~bs-kAD331wq1`!yN{#aw;ltG?hy~iZ~AS zE{5ep7dr8x-CqDJTdwB{`U@oE{G;l-)(+XD8C z)qGuH%|St1aa4fUIwiOPLemB8NFAPp#@dBd%TqQHYR)cEPo2*1AiI5Ofh?SHfXuXc zquf>@Ju#s;H-!aYZ#}$9yqo1|UTWpu(B|^hVpz66?*6ZT>$Uxn>Qx&NTKqsDdK}0U zT3cwe4JUY;<>1W7Ca*-`M113@=k|&3O0ceAC@4W`rqr~?<~un5YOo|K&u0VecW*;| zdMju?Qh?QS%lYuURwJ;9eY7yr(aZvXtTM(K*r9o?hDzL2T;4X%{rSX!Dxf*j{`sYF z@-3KGJavMd85~|5t~%?$P(GLS)YH;TdgtR$V6C!Ns-gg;fb>bnJAeSHN;_|~=K=h5)dhh4 zKmu|E>8t%Rp4NK181Q;5)Im9#u*={^`I7)k$b%YAvFZrh>~N2XPevR>Mu0EPt4pV% z2TiVtYPHuK>zz|b~fF-lO%NhM!;w8~bQtlN2%81>6*jll(oxt5(vwNBj z(H!e55sn{@=B zyI8dT<_O}uUv~RZ_-!lMK#a2pBV;o!QAy#b_6r=@*(`n00Rk2O>z#kNUmqt02nR?R z$EyMs{xkYsWp@wp%a>q&{uq+=$Ikgl^km-bScy0E=h10y@s zEoWJtDa7hFO4G%W))@zW)*zhOV?}0e8BsCA4{=#ftD%@bTfh=X>*jnzr{Z2=LMf!o=Iri%mDr4c9?a)lSPvRxU2qOHIBScKiJ@byl$tEa8Tj+I$9 zP?a9OC<1_lMcNvuV`r*`r* zoNZkoQI+i*>x=<38xk%}fBCqqWxrSSSD5qik^hvH@C%`Ltc}4au|%ty{Ndq2ibwtS zym~$k2z^5UcPr0f{WqV-dBFwz`sQmE%)cQxYrkKuE!qc||Am^lQ^*uJ&g8_-gi0)^ z8cfRN()~tCD@IEPv0Yo%NppO5BrI}k4p(GCBjSG1^`%7bsdCxl-;+SjIaM-RV@ZA- z=UGOX8u}JibqJj|;wYNBVn?B<%2f#2ZbWuI4M&>ew36$=Y%p#4i^ z%GGp9L0!4tk2gT0xpvxB>k11^D{r6$#4uPy!@9)^q~*}S7FVS*d6j=_f7T=a?Cgcy z0T)rhn-g185>S3Uip*v`{I}Gd7%#VzmRsv9ynar=Oa7cWRipJJJXoi?ImcCS{>h04 zz9ut((}($_zTy62bv+YkoZAq6hB}yq)hvU`6zs3u1aAlFf8zt2qLHE6hzV<0d$TvK z*!SEHG!N9CLxVw=e;?WNHRPNn8Qv~;ZG0xcoXcr=jPbfnLzy9$NGJAC<)h5}W>k9r z<}D*5qX1_FQyXreJlMBM{7s}==4G3%ALJ6d3l<+_4kG#Hwn7}@S}5wIvsy+|2yH(L z_47E*3^l%a=^2y*V`heiFQ-rHdj7W=cDq)o$->HmU+L37)&=&_1h)XgPVW80$pC{}%(3rp898u_Sn2hp%je z+&#ZWvO8P#^_xMO$lD!oUyU1200muwh>vHjHpnX0$cFup|0NavX+aRxkAj?Lx1s%W zyQDOhfMR_amC+jw@zV2JA90oL_g|nm z;Dpi?HuBz8>Mctek zjCTS&KYTAUvMFZCGrHd%DN}iU4;*P17MXr3T$D8mkk>fFew)B7507;^TT`=;eLp3o zN^sze22xc}f#ggofm*r-ziZi z=6eYAWHBuN7pQDMyH!oNRsPUhgpR&gl+fQ9ku4)4!n%UCkf*mbr)!MPjDj#^1L$#% zH)l~lYRwN`HQ2xJ)8!Poeth6Y7;ljw#~sdG*{NF2B=BDAga-Ux7$9}R?cQBzWRZy> z4gs}GtCRMdqkCZ0Y6gR6lI+t{hXTbkbwa$NZ18?)agL7NKYXZ!C2Xibrp^}a1T8u2MEW1?tO7Wheep9!UpP)tYK)u&uyJ1*+4ePTnx+S%=m z44Px;7QqEqKDd}6#no8>4}*t{=t^lP2wsp&oyFk4g@ERty!Lvcw8;iwEF_KVq#q)N zYwd#nn&WU3&&5EW+vfhZACysQ*8gVWp5Cqe{+WKy-o9;zJM_NGHKd%g!e(u$5;paR z9~{Jh=WtY~1G4;OmD1qLs;($?*0Su)q~ONI>EH5fa6_z53ch%UnQ)3FBWd?~RD2T| zlv#Y4Khl+JD3U?t_&(kRCb|NmGNGVr#n@j^V5`=eO0T1|e%X$kNW$S-bH$7p=>RV8 z4RAwJGPpQ6umQ{8q=xMC|7sM7&7GkcXru>i<;?;4pfdWiVq%ZU}F>DXhHl^L6xjG9J zAeyBY0l^U1v!BXo<>JLENVW%(C13c2qktz%^&jcR_QWQ`(;|}I=rEp(4*X2o&P2PW zWmawgKph?~xANW*%>!zAP4OYtg3dWn!l?XAh0c++!jGNQAJw=nnVfd1plqH6)Z0(9 zY5Ai$K3Cv$2Z_ftR3p;9Y!Szock_gew82N4NmF8{?G*5J;?jt5U*KBVq*2sQ<2hv} z&VJ@cPu9D^vFe%Zl4Q|XZu}*$MVD1U+i#;Q&-_3c4gJg&a(b+{nxk%aqA!}&5#Ezl z)5hoMLlSGW7^=bRO1+Wha~}tAfs{s}U;ccnWj+7M81+CjNcEZHWWjIO_uR7+3Iu7q zt@Dk*BL$4NUZZgz$7%)7c`$Q|@TtJ=x9NOMi9P;%x?IZ)NYj7sZp3V(0e%uCh(nVK zCdltW!XKxMq7ZFwwh5GnOWu{D{UcTLB4>l~7I55vdr7+?^^ayX#i|4WDO4`|Q?b*j zWR=U!COLIpw-cw6Gf+)RmLE;eQ?8{(rR(2I#tkIlxnz=Eiytrv6KoN@52zOiqjE{) zt;hQ>Wijm3Zn91;fwgw#V;iLTG`u*q?TaXU!`q-kA`SN=7th}z?0Jb|*czTpTOK>~ z{A$#rZ@x1h;9?3<4@Z{*)Taa-eyZ<8fqW3aYX(fw^-)w&#E}bgc)gPslsK&fEzF;< z@m$OhGO^dp`DyOU#DkkTTZs?NCppc(-Zt$7TMXj?0 
zJ|xW${xhfrEx1I$XcZktu^u3-*xZBQRrh|I!7yX~PxnR%+=AQ$+Fk(?tKGo&UR`o< zs96L%qH%wbsu2D8@dW^QRx9H0zgmEghveaRA)afS`yKpZ_hk`tTn-Nl_(Ln#uapgW zD<~=z(0f|#V?JT; z%Seg%1tO@}(PU&7hh}2{MOd-Q-L}H*EY3xfJJ0BfUkO8>BXP-x%I`J9(e(-8u0BS{vGE&8SiU%}Wi+H29);&cFTVHU!%0hU)GS4=@3zGEMaT^7?2%B2kWOO0V&2nVQ2oil2oIqI+O4tt~=3!@i5`Q25M>iT6`7kG`BwHq^9?yo7LUAhy zO^?iJ1lV37aUyqeys_t9gXPpflHS1b$iL!oV5`!aY4b~Y3`%?gOkKjMSNi5<0JIwO z;)4hc>O`$(>bh*Yf(Xs|4#1)2*FVp_$**}abznAJ(W(JsXm0i$s*2k1Pm%5KsxahqKf|vtO)T`-X@PIwc60GpK1fd}j z7kidBpbmobsl&du5jz?+IuLET+Wha-K;FDh(VAduCH)^wn{1mqy?;TkEb9FQwN4ol zWvSJN(Zy`12T*x;C(_{5ocF}B*TrbVMnb88(R6Qr)_1+Ph9a`NkDSfkJA8q+KfUxH zd;y&DR1kUbv+6GF6w9SZBExybV_2M`$2Wr_)ZNqTYl3;X0y7H!F9nR(og5eyS|vTT z#0~=CJ&H8vI8Wd?_?xJoglTe!)Qic!Nrqt^yarD1Pdb98q`}NWw$=40WBG{WD2^Nr z4gnVRLz%Xk*R`nKK`dIaF98Ew$}LRhfKTD0xl4foF&}0fJk2bZ%yQt#W)d;{vR>lZ z1v<5TtRiF9s+)#wTxMG+3YE!CLr@u(#cZ(ChCC~eoFzh6LNF2zGNtucj=1vSxT6w| z(8fjwXeM@L_A1t8+CboLPQn`Ne61M*h&K)dh-Dc!fH6TE6Ei8QXf?mBJz@9J65UeN zdv8ufr?U0Ls^_*Rab@f96hk3uC?z5bL~6;H`W6j%0W|8I0`IhCBV4CJISvDlsFupD zT&$Zd;*&95DY*zrF3IWA(pp%ma`rDuI8!V)BA&K2)#uwo!KSxh^MA;qG&xv^;8Euz zEq!}_C%zVDJI#>o9S9yei`@3Wh8OyhtC4XJ_`%3tni0NS|LOKkSF2Pb^}Q%43S==KQcVnBau$Oi^~__saGK-A@#&g+9v8y!RL{ zk$d~~kD(qMGXFbmt%&wDkspvEMDZxk^fdxgxN*IAcuJX;kNOzR zaiW3~LG6@lHJ(RJs2*X)z4jh60zDR7Nu++d$){za!hs6fKpiRizn_Qi*HNvPCt6;%xe{`@q>;O2K9<=3FER zu(eXONcl1bOk3j?uKi$bOTVkJN6UaI+aFY*?K16tEc!FT_xUT3GTl#QLR7bxe&}vyApr1qdxgR<_bYs)v<-Cj_kIW2FiMCKEW%04$idl1d%| znU2Ss#`sRIja(kHBiKvZVa9Gs9Y+@ar^d2i8f?w<4y z?#^z@1o(*W-u};2U}je7@_c;Bk2G}tzN52+vinYrTSXAHto#{xCqKOmU#)f+LBav{ zls6u%E;YG+^K<7rQnck@ZiZr0l~WEOKdcaMqQP9F%8q!N7?EoVLy^K|Uk}tz=9Cto zjuI2x0sdCGbe_t&biO`{Z)2-!6{y1te)0=JE{=h6!>%NIK#k3PQ(Cf{sKBsZaz&khuQ@D#Kw?RGX62#Q+96U*- z3>D2Cxx~aJaR)hK$#{iJ%rI3fd}$0Z?B7q)V7uyg22~^dP2IYJSGwAcM*uu{ThnEH zDELkjFLfHO<9VEtlS8TYE?QKFCU|@W5Nz=p;ytmzlJHvDyQwZ^Z7_Sz=IyyF_d$3*cVyFyVi8*3vgOzV@_xM2LyZ(4&xm*uL~5> z6*bgSaxoA`edd4nyE_1j#y<>0#~%-1%OF{)S@My1$dh`0)^5mP>EJSR`o5K(*toWv zYOp*PWB}YW{mebh#n-7j5(>1J%8_?DO2>MAw11nbxu`-->BP}cQS`4k>m=!HuGwEf zK&1^@eyHRH0S+jl5x^!D?&&%tHjz}JT?_MKZv<)We-su#;7GUaoleV&35!jDw2o-$ z8Nxn37p~)n4LxE?E2%QR1=b8dpiG-t=9GZVl2#(B=w>Yhuv#;KDh^bhT*jVkKL>lU zZO7{x@1kcQBs2SSzR#Fg$lS-XT6NoEZDfERCAMkP{?Uj|cF@Mg*Cvb6PGoGZeS#*w zU$^OgENDK=jCFxw*Vg++#?QO*kXMcHTylIfdY0`kG+hW)z!iR)w45ZY7ymd|SW7;VeEZ|F{O zH~KF!hwv?TS7&&1?1LYwG6&6cAshs$|C64*=?|sO4!x>G@{1|TWhZk%}!o$n##qM@OB z%4SsH=t)QjhUCFA&!~}N?B{RodI5u)ZH~N6|Gv_m_Af5bck=fG6>dG`#7W%=I|`A) zuM)SO&i)-yUapPhO7=SGwtN39oBasJFGzprfN~56x?n4m>jSsyyj(T>)|1IiFQOM# zv~Lu*6}8`0JBrHRsIR8we+*&e)oMpyc`x@OAgm?Kli=vu{pV+<==_HekO@I7T>o4zQTJ0Z!rfp}xi{v#q-d$f2#J@3s=PI$u+ZsB|I@$^JVX&io(qs*7;x!V5a7IlDV|yXf zCTRHAIDp_rYQk=*V3=3PbC0kaKcc-1R(^R%r!~aI#SQm8+~;D{Y;ROs>HRZ9t+F6O zY8ZxR5LCuTb9gocgzq^~BC#)U`aAyqPzh&BbCe&7aLNVff)NJJB7H1ACOaxb&k~*7 zZftXU*Gyo<(VUpLYWd-wo^pD9pJ?aMvp5gL9rG84|Ia5Vy=EvCt1596N6bC2Me;j1 zoL-4*$7wH^qx4PdXUc1g?zn%GjnjZAz}Ls46R{9Nx1^!a-UA_&yYrQ0AVtzfStZ>D z<2$4n$Q!Cr*weoYW$yToh_ByabbGm||9dbw@%%D#<^!ZUFlaFi+*192vPDG!x-q<@ zq{Mt1&Go}0^F^HkW`r_KVpdW|o{?xj&4V4_EK?59ikCIgQbOH4c~NT6qZMR^g*|0u z|7wnVisGi&I;{FxTVr1%EeUemL;aiBKl~q_zACKBXlqvl5s+>W=>`R)yQRCOJEgn3 zyGt7BZb=F04(Uz_>1J^zd++~Tc&$EbeshlTX4beH?3*|6MITYP2H!CCj64CZiI|ab zqBtORC%E^J&rV25QKsngFe?A;I&_4i!I)@k)u*nJCun&lA0=svhMr6*jE2J8bV?cb zqH=MAH;Uf_e$Mjr zG&7C%-}Y;BM{2&uew4?H*2V|<->n_9|46v&ZIEIkA{fKn&bQGr5zFTsfZ-eUgAEfC zQ!0lYDW!C3Y`L%xx*O3`irEoe%!u2mOZMAAcON(AC|GChC#m&)3>_oIvSaKr(M2`v zGFhF}Nc`#|f%s*9bbie*$%A%dMC?Lx21mhMpU zvwr9Hr|0I|4R@mVwq8OdREK6cp8;yHyffkeM|w4a1_?7oy#Nr)#D7BwNX8a1>;2GN zFdT>lorTb?O*2>pC6t@lR);5SF)Ebt8DCvcDriDPPPaF=q`rz?CAaMp9*Z7|A9;%p 
zx%Wi5=N?|#-Y`H`jYiXF_cP?H^#qq&9K=o?r&h@x+cEp+e3$uGy}c#l#h9u}Q&%Ue zvwzoQw;4vQ626wW$=2!bFB=LY9=SP2hU1eGOXz7*+Y0%f{`V&lH-Ye()}RZ-mPMO$ z$@OK;_VZ;j9>R=e^Akm|po{%3gVUHU7at)0A}K}*Mn)FHhGlx;0MQ(4f=9kfy_VZF zfR1{9`0uQ2bkvs3tmyV^Pqs*&+>|21TISrMTVQgbm*1S+@*TViUXV&Pu>=UcjxywHpcP_16b7VDU9gg2_R@~LQF_Ng@oM#W4S?PVwgJ8m?a z$_5fX9vGF7UtD%xi71or3~iW1=u=5$lgN>FAr#mLIphepZ~O5}Cx;DYz#$bV6Nh0X zczb)3O7+GKBEv8AS>eF$wzwoR=&qL9V7J5$gt0WwA=ZCnUM}?-9y`*6mhM__bpriJ zo}i%NHmh(HnHJ&E5U?>KS333sx1%z4pUhI4XCq~c6gW)#wwo*3LZV@<`M9FuENRKK zzv4WuKGE8neoQp_z!~9nbt(aKJ<&nC8HVhXVC_PwQBwV>+@wIfzc^sh{MAfXemWIv)ctV!lIY*gn9&CMQ%h~*Wsoi|#{H<|e8>5q z5e_(+>3DFCOpSVi%+$-PNjjeCaRcYI*1T=*K170?lZ{J`H;`(16a+Cm}m;}4QbNAeciC#dVJ zX1^>MsT@^+jdYS7a8dM^eh?=mKlr2T^X+0Ib|w`W8&(_6ElF!4y)hN{>?*8IZV!g8 zw44Ix<#6rPyrl z>+$ZlPZB+%X7c4R!aNS91O@d~m;5Sy9e-!*VJsvchVXQ~n5d2g{!^0zGzJvIT04-l z#`Lt%7w{}0Op1s3`Tnpf51MlD`he^v#>mY);f!+3p08;R!gOzh|R%Ex1hkcOZVA?t@_3Z_bol@hA9 zlNVZPZwQm3Orh|%Riq(7n9ug~Whms=N)sjB~&1Je+kJ^VbF+O~>m;QAFhI_Ne`RTt^!%M81|tVZM0-#M<%g*;ml* zS|OMwmgn4NbQ?EG?>zt|2BTV0RaI4CokaiidFxMqznG`6hVd`GR#p)S7Bj6j{$Tdo zVSMc&J^~Twv$HdUDBB8v>Wh&3a{P=W;6!IO8A0_kv{dmtoE{8*$9%dM7q?l|G20>h z2O*pL=U)Qn$L@~%A{PHFogT3(h{$w&Ksgsqe^XU>4^_&RYN_UxK9`*-$ffT--AI=- z&|Fe$3t8sB?35O1nQyi;%&V)kmZG@*?L)`_P}!F9}JNn^VPlW^5TH zOdMUBpvi6_leI{Arhl|kuk(pjieHXzO#D-|#bj#^1)?dt>ygFeNmI76`CTd37jj}v z-A0bAKJ?OP@te{XJ2YUR;!PT&j6ALo9{_mkq>wTpcgr zfK3uRiN_Jf@*SP08R%z=#QNgsY%)>|ZxEYZO#2ot2fzWu1q4VY_Cj$!ez+5uu#vDR zgov<#5XghgnG0-ED3y$sxnZ8NMA}e?()#Ts+IRHT=q+YUa*(gK_=j49Fpb#=L*AJj zr9Q}t6Wv#v>R-+mPF-dC5|aByF-`{xII;1!ITA^)%~YmqzCTv|9>fuy%;tA)?M zG`jx^!@60KfQ<-?;M<_el~Dxe{`H{E9*(XjiKEzRO2N?dzP#sc4AJf0E%tkb27!zM zXPPP(Vx;n~4U)>R4l!)MDfR`F(0_9yxv275QkAPaAXELb`t@h&Ys4q+B%&yYM155` zXXDKLpI;etj7}~nMq}%ra*k?ScN@d-3%ram8ir&milO?_1nb`|ssUW1(coWLYy3lq z_pG~{TUhU(_w3eim{{6iIgAkK&50!B4jOWvEH~Vp53tGFX5hd%0g^E@`Cyh z3S8E_)K1kr+}*>;vj7#3&{>;UVA-`vj7}#4_yelf@5abkPC(f>1!L^TB)1FP0x0*( zSNIkb4%G$c5m88I>kCKNK5g}@BxToNvTC4DCc5s)iXn(sBFHcP=T8pui3~YDH8c)(M+z{%_d%f z%Y^;ws8ImxfJ13MexYz>}q<5DL5i{i7+OCR0B@mVbbL26HeEm5+p|4yIQzd#|1Y}B%k;I32t zHccyhG~bRYh`hn?Zt{qH22W+ReRW*ex|Fp;-lUD3EK?A;`?m2uU7gDe?Li8S9Rwp z)FifpUwPt(Dv7yniARC#Z1_UeCe~tz27Lni?3OyR;h&+`&Xo2FQ}^WR7?6q zFUu{Bez4q_rbG6BqmPl*gG56-WK$uOCO>hwEW!Qyu#UX=> zHPXDZGxpoK&}lg{bgtX)GG=CpdneJ2JKUsSJuFr*VV-@u~5}RkGvL z9P9Rmo{@@&!qI#UG-yzGTjr|XQS}jIt)e{-Te%rt7#^%8(kO{gUUL6ygEq0PF&I!@ z7k|(hxy-A(1qAk&vgAaj$pHq1JdwP!&&B-#Y9K=X&97!w<9}D8w{HKq^M<#n(NuqX zpIsXsad_jzP1;c_jlB03wNiR2_@ahm1Kaa_4{c;86Jw&{wXQaxG-_NMm0;nk5>;+p z5bGIB9ypf>3Zsxf-b%-9@@eaiv=DUAuLnT?yhHCV7-)_{+IZF@a#y>9U7_h*M{U^?Y5kxHjkwWxfxUN__A(Hz!UF}qwpPJ=4&y$-G44YPJ*xxrx8}I0V-mO5> z~P4D}o+aH$i3UF(t`*Z%z({Zfe%6=++% zulRznOglcp!!0H89bf$Tp+6<#{g&qhp<^W@_EYJ>hGF%j{h=@oiH|)~SC= zE%wGWCK{Zm1O)V+UqwbH)FjLza_x8?b%4;08_l#i*MkP~iL0{W<3lCGZ#fUvy$ie1 zxt!oCr}4d`o1q>u+`0~Ip1SRxM;(_~$|$>D63g+)i=9)ICY^Hx80W~?fr@YAD_;^PobxaY1VEPYHj`!U8j7rI4QJ~l4_tC*h3$%0= zVsk6BQbX+DPj85NP2c|whu}h)k}u<*Pl^%27g0QupRff8Z_DYyuXXuBW32ecIyp)! 
zH0!^A2UqhA)Z@J|4I?cdD^;(_E7gB^c*q9^#T@;yQdUDDLtof{j&O=M$$nVqtuFsm zoora6bo%=BFYnWdk!>#&1I`3jw#T_zF;8LMj=kBqeln##$ap9E&YdKt==d(gpNPNz zs!P85zF2ku?8VjE2d9@3QGD=v(|vrH69;nRZMNbxKoQUCC(f?gcF$cK)B)-YBu6iMAP3Pk9&F+L$$+qq6&*i`>Nes`#h7 zr9-5HF1^o4pv*~^NpL)!^akZ=QTbAq+E{4T{2j>}e46YAXTqKgO^GISiPQ7E*RjpW z@*g73&wE(2;wNe44_;F@sx|j9VdIz3c=3Q24o488#J_kn8-*b^w7CbFsnvSFUMmP% zAaVvcq9ZC#$B2@+BbzhDr^^>}H?E;{#DEJ=_!oik>9-7CRUbb7qa%N?_C>rRg$^5| zB&wi|=%Rw$b{u=(x@m@V&M!D_TU{u)B5OPf_0qv6HCnIywVIZWy8}mqv>)ltX@A@H z%R-Y)Ht`0jbY@Gf4}Y~IjlgB&uYfhJr2R1>P{ZzQhhF!j7C$Tm6nLUbR31z#&^ob|$o@3w5*FnVloMOjj6)FJ^;RzS(>HwOPjT;5 zgj&6W{!a@axr=@3q*wr3&nnO?srvjKASW+;Mw{7OEWoovr9W*>ea@$EL57;cZB6F` zOfn@h2dDkEfSFlw9~OBA1H_**gimCl4G9cxI;>5}-!es@YierzBO)yWJn9#p=Hh-- zk#|UEa-d;QN>>j}bS_Pzi9yfhnhy%z7H3@n$?dN}X^pEGiiK$neWc5}1BYIQtH`D0T+Zo+|I=d5;@1+-YurGgtB;)^)MeZvitgS|x>>B2i zvUK1sXZV^psVd`X8QgdgxJsM)GnB6K%9tyFVjiV#I0;`UHHUUZ#{hV)Ov`?VomEsJTT8K_w|> zbQice(h7H49Q+cf{B+uV@2^->PC9E~rf)6IeNS}&V+ZLO~hj^52pO(ealZnJ+NX@Q|w&PK1A%rF}dQS1)Z8Q zpyFePiOK2a#JaF0((aRdHt@3N%+}Y#+8LHbXUUOa962fl*gUxgH%}&pWtLZsoGGS_ zjSU(xbCHFT8p^cuO;;zoT8`RI*gi>=mZ_%<;#gQ1U40dB2nB%`n|o(j!?@1~jk1kV zpQ*krB|!S}m${NSbv=4pAeQpBL0$3-Hl}MceMFSaq7@0-S6=6%g~Ul~)dW`M zjWX;pH;7&mR$0Z6>z{s?ZwcOz>N)15r+HOqAit7>gDH|qLAJBAYi!r2JytZLxHwFduffnCz?$7U0v7*VA zZqt3WV-L&Xj|xNtRCVq&j}HKeMt@^1t~Wlyo{Zp-&j?jKy9Gpq^Y)&ahjojW+;vjBO7a>GKtWR3%PO3O{J+2JljTb8K4Sb_DEcHj? z)ni~hMD1036R$(w@RhcQU*E&Q`}m#iG4$Ps-rIOF`u1J{5N^2^$LvdW@)L?sf zhUZvvinm(h=S9J6JnF&aL-dSTHZCg|mzgAj6%U}^U8mT{ibAiX$f4vW$pWZv58-?C zmR+#JPr}bqzm48-4xrYABXq zIc2F@RhIgp1r6Rj%EmsNlFPvqMJ%0mHnkaNg=Q1BC*5MV2oxovdMkby=nCGHdUP=P z48T3|_@dfY$`8)o`OnfLygaX0=%U2_(SpM$?;1B%D8t|yceLeuE196z$cDSxQKRCU#_Z6YGuQVQm1*GF#%DU`o9=PUm@qTV@#&(^ZuMLw*! z4wN?>Tmr5sl^dsjgaY^!?Cg5?j-$peV45gZhi{4|)LHI$+=8Z)j|Kkccm7ZesxmaQ z;yjUzeXJYUOo}x za4fS&?);W}kqxuW;Grm$sLOwAy2|g-I*MMRH*E{M%rwf-{kMqM2KJp;Q893WIsDvDP3WdXE2gP5>$cWK}qn$)Q`sua04U}rn7k#&)gcn;;ht1szVn1c@Hl%@6|FvpR@^kX4;4P z+WX~17_<9oIOY<%$0zVLU|C*()fT>k?*7N}z8SQLeSYI}`1AE~)0%&nI_7;M<{IvE zwl_t0eFnO!K#hCogd3=;V4c^3XNkjb0o{RNi*YEKgMpXk75|0D?w?q73~x&jx{FwI zezQ2zr_^6?+Ca9lbU17c@qV*vM&NgV+PeK#YWA!zLF*~jDeDQ^_Q)Amo1A8bmHMiP zQ%%zIYl-W8j)@m7oH6s@8gz@rmDjt}tq2Sh-PleF;c#MwsoK=d_jN(O-5!i?(>WPf zk9QnP0j!8E|0T@{r2j4+!atT$T4)~;P9YWO7Yo$%0UEFt#{YEm?(BUUF z?!OWTA9UgK=LvIOFi*>@w1Q_S&pi|mzjto;Vk5YeZh1MBwalt*v6ES^_b=(UT8~d0 z*4{s3xyt!Ac_Wi{X=QO#+0jS^qN?j-VS)p(5VmN4DUX>`M~ zGTpUG7?btA0-nk~SuH8U*T|Jn_Gq6K0#56eQ)WZkt71Y!Lz|9}O`KlzIpi&<9deqN z4S$M4JT}xsl3YYaFc3?bJ{*7ZU14B0o^A89SsYp9hik1I6B`x!aOS3DuA5fGyTH|8 z?)&065Q{)`Klf3$(4}!tczBy*^PD4`7QMsJaVzvFu>nzPb^Zp`BFlY~AZIoF6u-Fo zjZ?S(PZz^8a^Ll^PiX0;jpn~E$&qtAqdd_NdqR_dUp(*wM^yy~z2|%qgk2 zsYrRDg9cTle^8>Z=kCF&l_)nc>xsYsa4of9q;tATM19%ld+u;1OSGr#|$mvneF`~&Dw1MsT z=IZq)A4$Kxx2i)aBj0ZybE~`*vlH4@`B>KX_NYn`XnY>eLT(A?zaBk#X3gAXfXq>` zY@)wF{hrpwP8W@w*&3sHlbsi6^9hq~J&1?%8%hP%<34BQ8hl^U5(#zu_=l^ZJuEH8 zUAO?M5nslT6yL`S)3H~)?yr#EzaE0sVlg$2qZmGyuzZK4-z`IvLH9sK!p;oAJ2Ip| z8(eMhVK+q_f#~@x2OiAS))O_{7-7u-7X}QD3h)$5pVO2aryP&HIsOf(n_de3`UjlC zi#BbpZ?ZnM6*IndmcCHtCc}KLr?1{&)5~A2{3*hCBO7RW-wW(CGbaDQBWTz=xC<4! zrj9U}CNU-5!C~@~x;uvlMPh_QwfgRgMjue4%>FF#|fEXU|hj@6BuXnAQfmudOps!jz1$NXDn@CaJCdTU4v z^Qx4znR5Dvgu77u39s>w@{*#{-5UqNRz`0Jr85+z#SbDSJPotD^qP`L1wu%tR-zB# zwkr}lUP%uj8&AT1cqZ{Ebn(>E<(b?rl^mv1`c`lnWLO=^_Ct!$LeLuuB;ZaZTZV59 zo%1HrGdo^Tokhe5ZN7F4Xf9ton9T3jM0JtuW_D`@Q9tD_^wNsT@r8rNgI!!fN!f^! 
zFIDm)W+*5q`3@wPSaB4tIy>y$vD8WO(yxhenj2r@h=^}uVY}5?6AC4fV%(`qlrQzF*`?lU_HQ((3^l|UNo zTZ*aK>eUwc(w8gEWOIup(SmOG{Qg(7>lW0q)^xtP3ehdrxpE=1HJs;0Xi~Xq>HsDl ze04Y*CO4zLl!M*DK(J4~KGQJ&iN!yLCA>dfO-oB#uIsS%blk81%{^+AN?8b42n50Y zj{Xn<*k=SC)^7E$q=5oEKc0XfJ)`_Yy50x#Dp*(HVTHfrZuk8L!Jc2R4?3G4vi+uY z!djF(>vp5Q3i=W{S6?iT+u5rp9LoVUcw;ef?ic(o&75!be!ceq6M}EyCFdQMd!fASmp{I z;1uyF$wEqusIYEBao4%+w1ip7mm(|iF@wI?lCC&sA4I;I(m43U-$x3@()F52c+VhO82!mv$ploCbJ*}?&d z`ZT^R?jtefLKWFvRe7#$U4VGx&$=N|B6zbGeE9?NtKlZvt*XGn48%QjdGpr^Vf}AnZBP0}(N> zPvF>`a}(Sqas%kFMwt3vXui>P8%<>?i+p^}^WwoKH43=J)xf z2nK?@D4^H>Vp3@qx+9%F^jj1ZfT1UE@u>%Gd{1J9AFfl@lrh&JY&q9ftiftQQF%}@ z4nG#3vE0v1jPKX|JAN`3Rb9^|6Lv8DT46bPT~3P0bJX8G8Q6w;H@RKx>I3j0>_l*Y z>)cZbPSBj5xVKSjNqSQ$P{K?4M2U>&Vv6kpr#<<;D;p(3U@?IFRHd2hByrM&Msa5v z;XUo4h^i`I!AjNBSmfZx*IuV{weCx19 zM9Dp4p{I&U!^QQrD-tb3S*?)(zN<39MLs{1fi`)@bN^UjObfny@&d|CvGCQ`*-J0W zP@ku(lwtt`*w^gz-DL`iKI|+-hW);G7k;0Oefa|X&sO;b98~C+xSL@UZ$b(1${(k(#dL@=-+m}?T#Y*% zjKu`4ZUWw^r9ucEXRmrx69try93sunQT}3{SkzI0QzKKBtWiA@JUg%Ep}=MXJ&iMn ztq83M4UfOS*4gP>XJ{L6nw=a9#^rxRmzvb5-GTK6jA2+{mY&*$Y82T!$1m}l$$>iW z^{gpjsth5-k=M;;t37&OSVnj5it*il&?(TnOt-VcH*?G9XWr?>XTR@^Il+Mt30UWW zi&0|KjqQQB-aukQ3D3o+n$KB}=J&4Aj{nL8nCs#@q*6YGc~e$lY>cG|^o5IQ4n}%C z)$A>gPo!cmen2g;$THA8GKH>lMece{1BUJDtQ)44(1&Qz^ZQ~>vUPFtSSSk1t zpop?CvAE&}0%3lI1I4&JQK4Vvr>O&jMTu|E9GZ>CLSMVp_OR=M8JcZO!;}`HA4Uk$ z8<`vui443mv0C1r6VN&9iUEJJLYWYW(P%kIN1dq~Aznbwo;4X47wpP=oDjKBGAiio zNL)7R_p)%LCP{S!NzdNf18Aso5+pk_--$5B*lkQGxnkRwvs1nFst1N13lBvz^sbib zypPST-dca^k7@M?C(&rdT*}&|=_h5?lR~04zH93TkCl9hB1N5be>tO&V;plaaI48S z_!y0%h5&sbTu~sOf)Y4KaNsFwr8!2)RX+J@a-hcM?1@BO{RIJv!*pBv8mp!lFOrSi zqAYeGK!mU0aX|{KS;-R^pOxlre{z9cL>=o#O;xfuy<7h?;vimuFRXgIGzYiTF`W+8 zO^hhDm;@htvSgXowsgMB~Y!?zHEV~6N}HnmYpaG zJYVd6EV92Nsz4a;LI;qy_l8j)vN^UL|Ib0!=8^>LR@p{OtU~u5T`xks4 zOYUI${oZG__0i^+75q<=*17Yu{-^Eib&Dmk%(*w^{yqjzO^0+xS#EMVrEP zLp8bX@AAdAHV2-weOzuD-lVpSpWX?plBBs0XKuL&X~y%%pWmJtf`i$!8CIfAq(7c# zdp9Sng@lNfBph_HBLHQCvVF}GDV^^uj{+!(?{D>+u`XfCtw>*M9dx(E@s^dXeyZOh zT09g!esWw&!vecMbregX$Yu32H2snr7E&3u=z>A|eiqqa&>>3oO)t`=#n;!Y!)Cv$ z5`TPl20eqO#{{fIK`|S2LoodNSdP{!0)7%pC zEy=5mgr#APRL7FC60_>1!GcYUV*LoNMO25!LBub}fU#6xScW2p`Iu-)ZS+@xM8ajM zhz!;gSdn-M+B1q@2pZTnpn??LGUc{{vqpxg1o4E*jzO!D4=@hV;LItGo`*_=Z*=b4) ziw6G0pooZw{R~LkW?w`ExvL?M86x528sMLaA0_&;;7M$L#w&vqE0@!C!I_G(Gya;_ z{>sFSi4vNc8#5~v*bd&$_x~8JT6y07SD-iM0&&G63LzPe*su5!uOi;)$$VU90~@2E zFFyHJu*A9EVr=hMOxD)XEHTyJWZ zC>7VtM?{n1sZce&BMxL27Z(7-VgSt?Ge`Mg0u*V(<@@tj(_@&g_|=n5V95V0f4}1c z2${$oZDG>dTU=F-MrJv+F&T@lU3~QpPfj?~{*fpk5Q7iUu3*GS#LH!A4tpMd+Ip;@ zGCLxF5Q|&>ggaMJ=_@#koh)7F|7Udcmaygi2=meFky-GOvvH8$VwQG^GY;1M=Hz1z zsd+#%PU)8Np>RzV+`IA!r*X03lvN4%sQBuQR%yO<7nhg4fP@&Wt#WNYfow$XIY?cjyip97 z)$cY!8#)0fFJlOqXh+`512xlS31kpRU>>ZMBo12^MXU#@C6vURsFjsK)aEa{32-rZ zi9^IA##_zcOti1v>xguTu6z_g%)Weg5VZ+qLS`O zKx9y&d8(2j5YTL#0}C1DYc+(bl0^g$)*Gg|-L6q%Ak$zl8r=^|(;iD7l`1)Hjxno# zHST}Nk>u0li6&QHrK#T4hryft#EbS39m_|4(287N80(5S4XK(%H0n2I@Ck-tdHBF9V>8~C=j3t`I({eXKbRJs}% zG2aLMm})0Nf#G3f9JXBBp)OkHzpls*xRUagU~ytq5pUC5{vCYkQ;kA2`+y-Mg$keByuFwo+`XfU{9{%Hipx zwA!%QM?kRhVjGD`FTX?UHDr_vpYfd$m}1o{@#^Cqej=+_eamPuLmew;3^440c!~&d z^lvYmwzYKD6WM`H(2<};R~Cz2yMQ8)Sc+6<4l0%3LxAE?L_k~(j(?D~oB=NeicA&9 z{30Vx^I}?#zt6;8+6=xW%ryoio#87NzZUHO`(q@HxlEC#>^9D3-n3$7m&L z#|ZjoJWu-0ESrYeg8CYN*ClzWi$RxvBH(eY-j6&(^m?@XeNwM?Xq*qBgN=#SZVtdV zewSmB_IeJ_5#610!{@D7RHRs#kFD3GrN+<^H+V@%i~54MViYQcDB=W9tZEorNQv%sZ)U zhxa#7w~_-A@vB{sPG60cj^ql5P#S$L!DEi{!QbI$;Psc*SpIJRJm0*B`5Vb9G<6?G zhyn})^Umw>uU#?SaO!?$2G~T~TPV{maiqbD(;jKpy!$(66Kc{h!oGW?{UO{d2{Osl z_Hfb)-^2Qr0rFp|FclEsy^D%iTJ4BWtF`G=5l+GB3oykuGd0&;cqmQSav(>d)~LO| 
zWtAQkJ|ar`nP=S&YyUiqfB&1qP?&Emhgu0&yg9B=tO~LzsAwfkfIo z^x4F-rw{@(>VP{IR7RFeov5&z9fjEozV#*DMi^_6 zX1Qh|#Sa-<#?F+uD#P?k>0cV)7`pQ`d+1Hd`a!9$ifTUQ?o%MM=MSqAk^c9OzR#M= zCfBBJpZ?H*GpVL3c&@JwyJGlXx3uNkeX63fH#z&ou<{8=S`P{|kjn-z@o)FuA8u?0 z`1Gxqx?G7tp>$&6hqVFF>%!dktzGvI<#Kg}F273CGCi`2l`8FZL zwz3MLa3J^}yI_IsAM%SENY&B4&iJ}A>>QozM}$MQj~$>ec%pX`Kd8PAsKwpM4ve+2U%&3>E7vhyM@j8!Xr^% zV{x*npM8r46G3OASs6NEtjgF-d>f9)h%*zi?X2Kl2q>srY@`!utjX8ecpF}H?ZU(>zl~Gx z#`^HCXoUFBEIFAy3ChX(1KJ0X#&)Jp0YoirVb4?K;r+lgt0m?j4v9%W06O2d$|OLN@6i44@I zjR8b3ezCAWwyozWUW_?F2GZxxC-@_McEvXGJ1%A@FgXTts)(13dB+6x z_vii>-YGJ2duQ~c{L3i#+p7y+b?38O&jQV0ZWvIa?0iUMm*B;hN5=tA>}^^`t4}p) zCe~|^5`}TfN+uTFE5}bU4vBML)#990reOAte9i} z!oKc!ty%P5ZTx*bJKSgz8aY~>{{6z4n zbvxRKxwRWCu((`KtQ5Jy-xB5{dU;N&7UH{_Q+Wz6O z?Qr9v@j-?)rIR}iRuOQRTT7nnq|wJqffUaV(pWg?Rd}r($e;Lm0%?ws^OUm|Wv{X4 z|2b9$5%i~CHPj!KEtnnO_&p)3M@w6D`LXIXdO59Az7|r9@iHERyp)Q$y1RGJ&SI6H za5?Uy{?{3^1KH@35cU<<7;)<(zEo}!xGPw>tF>VkRaeT3!D zrZ8pkjzN1DLT<|wIzi7vMyLm1(>UUCL;0SOC-8tfCIvOZJ>t^pLw)Pc$G6^hE@s`ho!0HpSm&$_hfKr2B#=02 zVZ|lI?Yrgp?nGk0<0s5C`|x<+N6$B_&D-(F&q8W8x)C4Uj0%W_C`d{ZXY&1Kn)z-W zgS;9p=+jlNn9Ox+o*H5eaT_XOhjzXN+AJN(@-m1cj5{vvaNb_6>3_YRl~o@2LmnhX zVZ*?>-i9xKM2Ft&_c28yAhD;Jq zt;Fu7{@lj=*;MSi%z4x)J-mn#Z>&q#nBAAUsB3$=Hwl{Mvw!X7V=0x$bun8m4t5?5*hD8+k&gFFauM;1_Ub-)9SK<3s0Lax^5#OL3=3ycKcr5RQD|HTg&Y`Z~n_%4$KQ% z`o3%Y|Dmkc*W2%n-dHKwtinI0r`U zH-H_ORdPLDhcYlQSU-GV^Njl|E6wJKA#N~@T^O7vH)|97F(GA;&KUa8T<>NU?At!C zB=<@ph!zK*dSXqIo(!%2(GJ-3c3QSmk$KeOAUy)(#O-*_hrN}aH#q20ByKGZL~hnW5sKtp#2$Qu-8(}HCRm)qSrMaTVLI<%Tf5XT9K z#6*KF9`8pE)heARklm24)nZTC4=lb&AZS_)eIk8E`L9LJDQQb{Gchf#5UE|{@VBYDyK3a0emOvWNeYU$2Qt{JG*)Q*wv309bHVoPUiF>VgN^TgpX}_t+fTPB{4gNV`sM%T z4Z}QWBb(8RFJJ?5j$e5gx^TBk{+vVM0*Y7u#|bkmD_E4ms@l(mQAD{;ht{ekYKQxY4XT2fZ7+pke7*Qj!fwDS@~B zVcZ$L^%+T?IKYKs*+dIWNf?Wb-{^k7;0Z9Wu%)}c)z<~;yh1$*hSc&ozN;s93eVR| zHiCxN;5_8XMB%f-kbJZgDHKPNhX!RGeL#RE`h%X?#7Xpo$cDwh>RGQ$pDuh0Fd4Tk z(=?ca<{TO%oU1(caP+zz1dgG#Ly@SAjw7m-S{g82X%u(_1d*VfEzIU6?k(O>-GFBo zrpw$2$R$(>0O$-p^-jL|^MnQ!JF-i1$jQYk!nBL|Ei$KYCf+E7zVfH~Dz;}+U8LV3 zSCHzA78-cOAF%b_Pwr373Or?W_qN2il5 zm$6wd!!51_(?)WjWVPVJ>?~v^IUz=;$Lo6Ey#Xhhq!u{w=>^H7`=5V#qly0{@DxX^ za=5D@O-=k%aVKl8bdk{6S>Xo&7E*v31Uw$`_<+$cpIvFs!Na4|>I0_gk@~q#%16z) z@&o3xHuV|o3(sVtUj zxSs9H_{HHt0jHgwiY)NmA%iKrTnBijDq`>-rkSrW7TkP`;+0@SZY#d%(Fo|vayub7# zocN3EvB5}(1!z4i=F3Z@RY6E`)k}&0%aUW%f`QhnD)g2*0nHn7X9_X}g|5D(HXodv ze6sonsL-^w=skz5X*}G#_A)uM=do;lB2YE`3)omh61EXZ&%A4@461K?Q`9|ldWMBv zSz+UW43M2uY=PJuKeeWc|61!ovnodVgec7Ef#)-91Y36RepXT}d(l%*_u^-2=`~`s z1nOj%xI~+E5KJ>sYb@2yP=ju^8B>n<5h&trWqCme9<6&Rw@i){*ZuhGTh z4~3uZ^@jhaOqijYg3br7rz_T*-yP`)5MI;dM_l@&vS`&#HAU_jWYPl%ICPyPHk?kw z-^jH??0%;LbHyU@*D8wHH8jClz6+lYrtq(*C()$K|2)^E>dw0Kc z4O`3iY1zf3X3K7V`q+HJ`IVm$BO9@1iE|QulGG7sgFKzdr}`oE&376+=tttu30b(A zEZet}b2omoruFCf!|TZ3F&)ECxi)Y9l&VN|H57jwHD%+gOHK{zkL2;nQA8dyB^&|Isck{SXB|7P_FKNhqU-UMjbg2^m2FiOw(mX;V+VmD#&K6g=G z^+i#DmxZY+pBAB#ooIL1s#=NCgFoBEH=mwX?gP>VTuQyDzN4smlc2yoSkS5miI7me zKc=(Co6k>0AhjC_ELrzU9pn6lEUEn<*@{jBl1;jQjU3)06{!r!1WRv>O~h7o zm>&wPK&f&`yu3#L0#H7*J>`S4T|Ixq9dM)dn`V3MHZuefw9q0k#tN^J!P;YPBdU%m zC6bF*GQfJJ20`|+ef}P&_)%rere)5&BpG}gsJa^C{ui2q;zI_rGL%FoO4+^%W1Ke& zdPf*tPP9Pt;RNv|0cUv2O>u*q+;=XDj(3>CP-Xl3`!M@ls{7}f@rvLrys$Hhg(I)w z3RY~;dw*HVL({j>s?sixJil)-pb)8}h9TNsx?U8GuM7+1{ywq6V?@`dx1;a~7p*(7sWB+X6v!od*hz0|PhIOy{o?>wDJbJcvz#lh zDOn8D$^mU>Q8(x3pMUa2Kdk-AKNYVp&0lk|nEplSu%`)WM`VbhFaAo&v=f$lb*74X zu;(+Dr~Yf1dZR3CO2RGQuMCQ$8SNI8?5B$Ot(Ypt!vnN_P^+HfIHdS%PLy&Q7K=d9 z@AQjl+x6fdwZlv8)Gxv1loYV)xi{Mf{zN9;EbLbE!~|)eU#H+Zk+?UtlW6v&q<0<| z{~~^r!;t=9f5^E!UQ)v?`Vok%$0eu1saGB!8q6lOW=@>&R#X#Y#FCv#g*iyJDZRL7 
zUV@DOkx9k5i7NoTdhpsiuZRA|VTaSh&YlUNh>5f^bN{%sS)oLrdbEW47P`z8cVss= z^X)1agx`5Oiy|6-tF?jkQ%CFlr~Bgdf7NEt7hYPF_gkv^c zb}LXCvwoIwBd^}S-O~JjT)kyjT}{v>8azM<4he)D+}-`4!6gvf-JReXAh<(BhEH1}=Jaww-xGBXO=}JqOqy7i)H%@LKmxw||l8FI>)f>y% zoko&SuCD7YyFb|>X$3q2_Fm{_VWXgw_73Y(l0IRDcamN#FR}jN?-BMyo`f%=&zL6o z(;!|%a@Gi#9EQ!Ws-xr;9lTD>(vo9^Nuxr-p{4wtW#j437uppwnmqe+J;A$Q)@R5cYVWhiV}-Un1XBi21r68T?E9FHBkxb3!T^2B?g=*K{_)x#5zHvzt1?OtARNz2#xQIah- zPmZ5-4S9gO2Icxcpx!>46Mafelz)m~^b3R#t=dEuzH6v+_%n|&pBkvAm^B9+yQJNb zL_Z?o-=v3s)YNEGP5w}_2rEi%KQNZMu3mT2z|+wMO$6~I%+ts(spKh{)&%fB7Fi`m ziU#~8Sx{LL%mKLG6eICaBunC?6#4tb+!M9safBGlm~#iFQIIR};_$MuJEAr;S69~DOA*QM z4V_=4z1dDx6qfzP_^sWE zzt)Zu&DLJFqlZO_=^?I&W0Lm~-}xfzQ?qye(rZ>oHuMOv3#Z4WoL}#}OenrI zw&DQvLJF#2+3jl%Y;1&1r!miBI&`r=W>D z&H*{1Jf0}J59R{6O=EHkDqHe{`2pRgGCV6!pQlr%3?D7{)C|ddi55SG^svz zWOH-#?C+t=aUk^*Zap4rs*#^Y*GV(w|U4Too&i0UYkkGiy75 zNU@i(SH>SQ*-G6>(@gWt6DN32Pi=j6Tgziv30!AwA2Eu4`jhX9bS^zT?OkciV6V)N zuS`DML0zbrB!@_jslSFt%+sOaSmVtfn0HY|1?DkAXufLhi2sftbCFmp2TwMDWgU zmx>7KzJ?wbuh=Oepd8sQ-|)6fy%s5z$9ay&P#K;QjfmAP{=)5RgX>xD!n$35I}a^XeJ;K%`ng(EI)%1{Up?w4d;K5$K(ChRAXaQR2Z2T7>D~BRH+xIXv0!!_2+#ZhNCMdaSZaf^8Urvgw%l(1 zA(Go%pF$uR+zBItr=*#upWe^oz%9kT_W7e~Y47G7s@}RiQ0(WgFW|FV6wkLU0Z{*i zi?L`xOdk|OA+N{N7drF5(yS?s#gV#>wg*psDFZHN8$s`TzBOh_j-N6$_!__i?o5l0XvYAp5*VeB{3LYTr6qb+ zP?5hSM9qvLM>3VbSm6jj6R|}myXZ`Rut&UBk)`3HpPrvZ8~G)&uC2r+y`4|jWEjP- zm58Yl8}q)fkW-N2bB4q3^KL!%cjS4?w*V+4zuNKh)|S!^FvZDh?8u+04{5=nQ-81e z@tAN5aKME9fPm5lz+YrZi7{HP{}!0OlbMN&>&t|>2w6q>C;0wU{*Z|D#q2h}PmtzA zBCA=C(VMb@CdN|M51*JlnNdBh9Rt?lL_oT*Sjye=ZCIs*p9e0L#^blwY?}%s*b>q6 zMk%#RYch9}6YIvK`{+Qth`Beo2hfAk^sUM98<*b<18Tg9;MwRU`^OurPBGz06I3dz zT+bbGn@EH3_)`!>7+#EF#XTT!qq2Zxe2*NIb@P(B0Jq(hNk>d}ZR+C=SG?!yq&&7H zouH4+o0;^rH;E^nh#Haw4UXxN2$7p-U~=xum-n?!NeLbLql~f3mSs zEPrLsfh9_7!g8QBBb5d@1JH>60UOI(ox@L2*7z`O1$I%>xrwhM2{O+%6FewFwr(C0mkBz|tz4?7A8<=inGhhJd+0dQ@No7B}styuxbposJTxkzzUZ8 zH4DNY>j5Cc9h%Hm?+@C_u&C*fu9TDcjpH8B)rdnQP!&#hioG5;mLgvj*N91~FXI1^78!(S#&7Gzz7A+e+~ z*8E4o&A8?qs00Imp)0-B15?(xJqM*okwtLdP0>4L;_nGo(qez(W-_6&NEmf{SgLMr zJ-hr7?*N=t5P=Rx>URS$kQ_WA&5O#y*h}r4mVEHRh#1pPrH}dY^Fj){l{z#=TXRa2L91mwwuE(XIifH$j{b&3CzBI9$cc%pf~x+ z6&QdIbj5C5&C|g%+OacEksz=aO<@Wo4v>g#pRFZ3O2ke?s?VAtIu7M^UCS$Ds>cZC z9Rv*XzQf^`6XG3Y_hzJL0A%Vxc)sCAmi^wIJC*ySHX|JZetbfV0do~rS=^Wz6PxjQ z5U&Hi*>s*@>4gmjz2kh~d*W3T5gHk~8dn$;@bE-v?jeu5fK%%oKV0Re#mR#(+Y8<3Z( z=iUp90ZHRk158y87YAcj^ty-$Ep8X=$&Yscba1XcV!REnPhF)0h;Qgk9L7k{B*UJT zqsSQX)qM*f{=c&TrGshzv|vJU=%67yf2f-q8iF>e6avQTernuojQC00zcOwSQ9t(Y zEhBBHU6FL&#whgMAdPW{D+{$x+LCR5d7Z{qrUBXmLphmZ??1yky$5wsi+6dT`Kb=V zDrHmGepJ87HB>*BAN8rPW1J59MuWQXVT7Vy{Bg)5oaSgN-|yP)a?(uN!W)*%UaG)I zIEmF<2m3g9h<{8^rx1~t2{|R_0IN&Ma4ba)uumdX0L{5;O#onI&Z7G{Tg(rdQcEOuam2?Ln#NjLtg2GjTwoC< zVPMcCok$p-33VyAM$mO^wZ<%Vy@iF8GOraH~ z$p{IxrWNrPW#2lJ+)S#IoTw-?s6}Z#ccJy~fSpMX!2L6#{62$`U(LtsNS5HZ-p4G` z%a1EeeW(OCi(H2Zl`E9CbF7!(#Pqv$9+&7wi4_m`G}6%w&^sw-=jdL{NWIG$e11iQLn(aD_s7SvANhTTSf= zfH3_|=dj2(rXwZ1jCXk71QgHdp0_V;qH0a2>z`D-OL-!kP@dKk`9j$OH|0I~kRp5f zo^hEnJoS4Ldy{4LZ8b(8&;)Y``eE!&&!rHJO!~HrBviB`Kd3&c2)4-L2b@Dze@BX$ zvABs7$c(Qnr=jtNcoQ&?r}9|n0g6G2uu({yiDP#}&v#pq0UKuT%OV!VMuev5rSi2M z!{>mWo}W>fW1cyW<7Y0FUvhoF#F0pmE5(8Z%_ZY1r*+#r?(L!T5n0j)WQ0uL8H-9v zc;c-S1q6QG9JtWh<`K>3O}zoX)?%MDhKd(pvuGoqfbRKNbjJcP2oVp#Flm08wp2vh zwVTySge)k6I#cVLnF$E$o;?Xm1}a6~SLjw%C281#AD$|>2ciHNPqeti`BYKuMjUqR|U@6^E*nhN0p47+b-mQ;*s zO&4bh-~k|c71#eU(Eax(z^q2uQ94kmP$9R(9dj{R%fjc(#$SFefWx>~9?;&X0EY)} zO0y~V_7Q%B5npj|G&Cpn_KNf_3;mXC?W636`5S8zc6tGU7a+aBF$hd zb>|TZeeMCg9P`iEOdrvL)Vq9mV`URL%Sn9`S*#7L)CiIoboimHMoATgRB6~>D#lGR z%vY6zH$l#R-FDlW@0!@0r59{04#J;H7QAq3sqeZdoNxTCYu_U0sHQ7UZs@L`&Dxi_ 
zmU_ioY};4Fr*gQBp){l*x!3n#>G0@=(pSFovw}cQSCi`ZUsK!LDK`G&m?04j2! zYw*@=LibB;#)pI)S{}~E%;qb^%8)Kkgoo-^-{#z;P?$8444mkEwgxQZQZ#8 zta6dx9LEIkQFq7OUQZ2z_Kt56vup?ZHyhzG%~31C3V2MjCA`v;ZJ$t3P!Q1rd9FG1 z-(VtJN2`xJQX~KSKuekLV>dN&uDk4HNSf7clJ3e<*C(qQ9ow-E{!ijd_Q32ul%E_J z8k6*k8Zc87mR&!;Dc*qx3vMRMiy>;Y5MnG;@2s=^DUllY<$v>8*ap}|Y6o_9LyL*T zzLP)Wv6fH*VXx)3-Y;$aqV)08TJKBIDv*7+0r))I;_F zJM+Fc3@gDS>)k|Tjo~0hbV?EXhSsR3<1(w+6e-kLS|=&F$fY}{=@YTz=>l)`YlvWs z^k*&kSvke?#H1bnR|1jtFexm~I z8(nZ4y18e51GJRFSpxlLCeKHBrGV)1n}3TCTA=47S_+FvgzY*8a-xt4A|m@w(ZLQ5 z_=7BW6oOKJs_9*(%wgy$CD{~!K~WK7G!*$wDjEn3w~?zVonoGJmf$|O zDPTB4{iW((q|5PkusbmB@>YyMfA_E7g47N#*7C3dr_(|We%zeIi{#$Qa6M$_2&)V3 zU>K*ijQskRU4IVh!-1x787fO$sTxQOeSHb7L|1lDl8@!LRdpWTqj_Xs-&Ae6>dzlN z_s2BUe`ljs{H_q^`KpmEW{cpAO)&GdG#rm<;C5pwy8KHF&eYJC1pKer`tHe(X_8w{ z3abpBZTPjL;i70Aq1bJd^Jx@gv|%dZS;d z;{ai;;OL!!6x>-)Hd|*hd!;~)zM0rhKgfP!nqz8U zc&gw=jZ~v~A`mr-+)91>YD%|rDH&+7;@ytNiD)JfK5=E|w4#jXZ3PFr`y#)>h0Ct4 z{%+R$i?LcB=ImK<5u{dmzY@CaNISO1mASL;V#JQlVu}4^$bH!;QK6|bcvb(0eu04n z*`+;(OmV*1ijIKW0F0#K&s?twC%Ks_XJR1k{7B53n`wJGX2t7?_c8q!Fr|B<=NxHp zU#G~4JBzwlC93bmfFh?UDGtn&!|px;6RM&W2@5?RejOnd6CoCNe)y_2`L|YLAxaMX zjMXDJhocNS|LS^aeus1wO!jf9p@9=>M^N|&I`^N!4H5gDxTucxJ=L1kmw&|#3#uzk zt}^q{RN|HRm2>x~+=sNxSdMsrM4 z%!_Jl#XKBap;A|-3Q%p7z9vrp%(Bqgp_j+-Jf3I+{K*(rJ)qlVzxap=&I=SsQN1xG z#`1k1<;q`i<<)7ZZLkKT=rKdwlth`zseefo0?e<0_-howmP<2eRVQ=eKdDJ^jJ09s zWZ!9@EkBrA=FlS@|asSbfCUA zoRMM->gh^~izz6;ZNRrh7LWaQSFWjbec41lf45u0{j1H2n#6$|rJhh!4P^FFtKs$G zsnnLW%o-f=B6?|TusYP+V0)AE$ug6;_k64wly&QK8!Kf}bW(jeprpyU>D$nNHj--X zZ4_BVuJeL>bk;)V1|eYaO36b#xC0O@j4iGB%r~6(w*H6=D50;+ z>m7&ek_SM@Syh5E2wg=eYYslXk5NDs3 zSou06qw;LP6}jC}>WjaeECHYv0rzH<2L&hkmEYsNvT~JMP`t>89~Q6HV|PK5BQN^{ zdKH6HUNdX4O7f&EV}I!I-x`qbo#d=QbLP$quTcL6+uSZVu?+aX`oJ{TlO^e&LY?aM zEJ)anTZ4f~-AJuoUw4Bj{v!$6fB2TAV*2&Zrpqe%C)T1*M;sAydOFyq0aa^&I@mG- zB?38HlnBcHl8P7Gl5R&o=0sVvC#fs_{gI0F!;^tuHLnxZr>|W0#`^WqtQI0#+l%5cY?5= zGgPGQL%&b&Jvss4pxfMAD_F~mInm@abP5`^_}h1!JTG1JFjLzEki%{F?TG+rnh&5LYK^ob{e z!OUMst}64@L+cBr3NHiFj{6o!eoicD_R+c}!t>etW(>Ne@CaqgCKD@=n67vAGkv$7 z@b3?wtntIBxbx1^JNqm&(F_?qjez2zRRG_(*U z0_Aji%d8NC+-(uS-fJ_&V8B_E!4&Rm@$S?58Vz@&r7=A!Ugs_uD-m8l~_$iI~|Z_osK1J1FO= zHZ7rRj7~us*MRh>e;gv${qnP&>f~R7xDPCEV8?2~5vQjE$Jok> zf$|UPxO~o{dk)c3pSc3;2}!Q2l)BVpKk}VslJ>}Z$lmWcaWBNJ$-+PUbEr%|8tm1= z4dTCZ+w+fXFXe zTFPY0z}q8#Pv4^4xP6ebHAKOyT*`d04#q3Fe#!FvJFy7a^-03*i!opFpF_DnNG?g7 z$6#GXRh@VnJ*q<`JA?#h zzb)~oJ|mq5dj$MCf8yO2&!@ZZ$kj?>J!`!&+35zOg&?M;>v2gBLcZMnhQ|<9$Ahu< zp`+fw^4~mqwKpNEq@Bq4aR*`cqUHyM_?F1PLBK&AAm(+uGHtmSqH#H{otF-gsM4qd zrcEVVlA*TRd(Vhq94oyvn=$-vWQr(Zamo=zT7MT+FdRvae+;+VbE5nm@hSWkoR}sm zdBdeO7pa7R(^dbE*<;4D+7ro$>LZ!%=o6!5$M~?E0K=_H@*I7CfyrZrTWQms#b94t zpeO!3ZTCL9XS}h^k~_{0_5PMv>_}uh0uHfUhtAzY;d4v7{eJvtmxK@O7nzQvYd}51 zeEklfNSXp4E%v#6JYdwpuZcR@{;EoYH~T#1+F&jvV1dyc<792LEg?3dEl-snG0MjX zI=a%?j@ghxc5T6Tn&UR1=0}Xe(A}Jf8U&A0F^9azUgD7_@}quQXb`KxttqeC9`;qy zW0C}u1iNB*g2qP-l0wZ*=l!)GprgP|1y&i}q#341Nbxw$TK<(GrR~O$m1dh0GQ+hr z{_FfjeBYNv(T=1?n#yk&<1`o-9W&3f4;|H%I#ut@e_T16bKIOR`ew@U<;^tOjem|S zFW?_us&+vBeXYnL!mwHa`gzC=IT}BZ;;p2#&6+bxlXpH2=3#%W%7vZ^KUnJmvBkcxw2%Q0%tQ(Dk$Y++X6u?0#(Yoi* zcGV+vyT%V;XiJWVpQ+dJ@;Pl!%iL=vqU^qt*4p43=XcyQpj zG;g!oeUdG>KhMNw8OaWAS-U4=Y;+EEqfB6UgrTv>yXZsZ@?RMFJbiYs&5@b!&5Ax* zd1rCoJ^<)KmP2GwG;whVeziw>{jSY7CM2SwH|o%_c*R{(A~>8tCZv6UKM~RDJuWCl z)hRD5k5<7x7k8pckQBX)7Xproq6i6kq~v2>QN%|pR%8k&j0qH9pqesRsL-Q2W+Ed8 z#Cxdb#O2zX(Z+9W@_xA7#Typ1PA)mpR2>rywl}H6? zpv!xblRG`}80l)>zZ5A{x>4P#wX|39&L{Z6u7vxr9R)ghUypP@7DIF2d49#m_Y)h! 
zm0v9y(F#}}#zn*a9E=oSlqDdz$8TG?n-q1uRWGVhd9qj8;4)c-Pi8S>_!WZ?>Xyuw z?de0!mipad-0f~!GGg#kqo<-sZcIQ?hOM^u-R`y6Mhn^3wg(a*ayJVuUqCE9s_Cw8 z8>yJ6nAty&?8XZh76bs~W`e5fgB|1OmsnIHgyqFB~5BQLS{4lSg^-s&NFYe(9UPbFmiEj zwe5UQPX1wAga>~B(_Tk2C)&bChU?kW)62a;$Nq@XqqDVs$oj!vpK9-vTH^L?|A^7! z*=)kxSW(RNg0D~QtxBnSt&m?nd46+yQY9xEi5pvh9K8)sO`g#;Ms`?m+C0lK{!)r{~@T92pYcav{k?o-)Qn8&;_W z7!(sb3!#&`k`XS_eE00_o@YJ9jD{VI^&u@6xyC*$W}_0T;;<*f`6Z`7v@J;+q$v@w zPEsBm5XCcz>$(NA9`Mg0-$VwK5C|(y=#KbiZ$11s*{PBoI;KYok4YbvU&rI40YdFu z?}J{5>O`-Di|5SrP!}yKxCg1wu&`s z&vvkGjqS{&%?88Vyx<0RyoLXr#peEGdAIZw)Tbet?Bbo9oT(=aB#l)9KyD#ybF1>( z;+w!Av*~PE_U4lHJ;VVU<6_;QcUQRSeF**W_qoV>p30lse5Rf8AGf7ypV-qRp@CCG z1uHVq2y)Qi)3+gE=SN~f{2n9|^i8^@rRda3hO5P%*yN)PWx4hmz<{8zcX1`z>z;4u zK?n+JG^kQNP2n$4Z(*Xwwh6*~bU(VeZNxbeoeTfjEK8_D57J=bV)a(AXZBL9K<+)U zayI3oU3q~mX>327F9&&PI+LtE-Fdyxhv6+?aJ97jgD@hKndASIHX0st>6MTeytA9M z4*R@7oomJn_R?&?*EYzrmifw|3X-&_eW!G<(f&wEtt@*aX?0pw>XK(8wm(P^UGwws zskhSV;eXBHP1vBw{VNv{LObvB}6=GN|eL_wla4Ddh-s6zw(I(>@e_3udSRl1yjI6WBUD}`cylr(^M>6d>=Dw*&9ZgKOzhHmj7CRih1)8YEUL_-LL zQ4u3B7?V4(H`>^w24*kSiZ}o=jHYXRgDDe8XTR7_UxvynH{LbzMLM3{>s!W?Q3>Un ziSxzn7xT{X*pBda=t2`2q4rPcWn2D*{#L!NP#E+;IuoZDP~{PidJV_nvahLlFs`&nG; zsOTZ1%Y_V5q;xJbQ`e93`Pxgm#z)F@oeg^k8Lq`QcbBJEf`3>B$)xfN|A;rQ<5xm) zApa9IeGwpa%k2|jrfu2x6O$(_qitD7sl>UEy!rNV`-gUGj00Y)_g$jLEu&n7p&iG* zhY_okj|(%9qex+Jx?NEMgaYHLV=jl?9*1$hKve#_cj>g8p-kMBhO>C?cO73_xQr7n zY=x&n4c58_&R?fTKdP0cF3B#AwtEFY!CaA9H*Sj_^?Yl-cw+3kS{GzZ!=F(?kL_za zZ?oSgGW)Zpt>QqGcg4dF(HU%wbt~_Eq~QyrSQdaL%>Zh@XP2YW`m@mx-PIFPjxSPn z!4I81l_B_Uc4Fm$IVB)ur%YFi6Vmx^mkJc%viw0L4FzZCh>(&|5HoQR?RYBKLt}uv z=*zBTR~O8x7#YgH6zoEk2#I>dt*=NpeLIz|0k2p z4YinCFv)|3XgFC?D}Q-~N8OxOCy%Gr#0Fo3VeS#eMcyKT8(CCboIjvnkrNucy|v|s ztL*}XNg*o&40*=w9?|WF{#$an-XW4)UgCX$U3+!m9#SeOk(#OU8c}V4IWqR}aoCf7 zC7dwo=4<(0#tHr^BDQgxo<{_@(7D{l2M-nNsW)HPv|Mxv@G7)3{V7>44n?2Sm;DoL z6V3sF1;SaVo*v2n4rfc(219XF??>k~IhdlE4=;AFX_)L66enVvUHfVVx~k%7J-hMqje`6I}&>>m1ICOv5Lepq*jm2kjV;?UU`8xKP)Hgg!x{tcHOS|j`+;!y*8Y1CjlfL+l6 zMfk1WePUQmJ1^3n{%qyN7NiX&jmRwZQEoF+i_PLRr}<0YTiBlC;L7c~`DrAXjYPu_ zjPJ*O`im*cl0@XVdw1%kY`||IPxWKVZs}rawGBy?Cyd z*88-u-K;03J-D|TvK|6EjMrw;I8yH!!& zb@p2HYa^H*Lq(Tz1K+$y!XDWAQ3RL~bc9D0-?~pw?BsJB_a4%jzKK;+kQ3}W$ z(_elvHvY_}K4VbtDD03{ayq>Ges#E@dT+ki=T~qo?z!`v>DkOwV4|_(xqfazd7?tR zyiASq9wu%3BD3%67!>iKSC~?LE&%VA|YKVY0jUOOGYo~n_N6YVu=gY5)SCC2f4o0o8rd-HglPM{_;M?U_1~9dp=uq?ZhgOHP!9)Zt z^KcOe{HyQ=ubFqcTq;ruA2L>f9a?L8K~d|Mz7n2Ha+fK z`i0rpff&l-w7iDglYNg+okNlb{QrtUrF}=`DIk|!GCcunL^HG(4U&?R>TaK{E!3HX zV_%V-*e?y^i>uK&}2>1V_0Zc>sLyc{qaNFT-)FC zN}fH(LjS-5@l(^cqt3zMY^yAKU)i12ZQdOj&LRKRvrLVMi~{DK1|*i>FSW1#JaOlB zc~APpd2H~3fC=x^n^~)VfkCq|e0bhk&7-Y&`J;<&H;}6__b?mW+x?)7b@D1fYb0l>wd_Y=20RjnTpXw;hff@dOZ$pg>KP_pNoGHAH^ccYk&7&U&tZ5-IIgr%k4^B5!{*siB5mQJmja719Qpg*+KEZ zPQ1^WRcUA=`e*##Br5ZYB2%-JRZPG^#e-7w2YB}8-+8Al7*@TY+|ASecXuSUz~H_*3eF5xDCqqU ztaj%AAizHJCtLHeky;-DVSi33hBSDizJ6VcgoZ=jh;|EitZp8Rn%wn#&9o-dsVRtM zmkmneta__3N!KZ%eJO-&pY%DuPA@)e6uc}|YpQ|gXjg*cFSWaZ(6i4TPtCU8EmhP1 ziRXj8;sp0l-0y_$*&QDm?YtCe21rt>6$@xny6(5=-u~x}8)3LqgGDgirbA$DB zv@gg={aZ4Md;Pt|&bFIRC|9N`Jc>rf*Eu~sBAR!ntYxq|3Yzf4r{CtUYce^TQ5T|SVJ#IYk7k9XL&R@ zWawB57`Q61qscB$4-1goh!esYxG!WwvmMjr-jTQnH6$GPh-bHr;T;L7 zz`4vt^teQ^eQ-3#ZMj$xt5kp0b7A0z)c(}PE#nC@Nh|QRC%f~QS9U6kdA$gKytC-` z{xre$(@TUesC$Ji@LWG#Oe+`vqs<+YNH^K+3QgI0`I>T;TGEzsbn zGnjATct-Aeyu$OaoM21PR{+2QdA*AufPqJaU#hEZrA_oI(^U1l$U+C9ofJ=e?a@~gMd>H}Y z@nru0T@jzVk6N|Q;jHhxcJ=>T5hGv+-17SJD9IufN9hZQhOPt9k&D40wWIlb-2Z=N zzGl-tEG9G7F0N2q4$z2*0?|ui=Pnk&E*kmxm96#`B73x1AhW29S{&ZqM|9&xb2!tp z7M6lJPRDjV)_8N zHJ#p`g0g>xelJJmFtht>t4~Uyilso6#G;`wNb(0kD-c+ 
zG-2TRLR3<_qO8M@HJb7NvO8Z3^3J#73_VQiTG=Bp!Y+s)yuarLEp9~C1LN7g> zL(YGx*1Ug@gcpf|pGR}@ucwqC_7QFUslt;s{Ind09=6UBN<@c%b>Qa01^(?)@5gui zTdYRM#}5VQG$g?EJg3vYPXK>B5DCRsS6|UPG1%Zbp@`kW@1MQmw*Q^ zqv8AYEtr)CDCwk}8`CQbbk(Nhqd>YX9+PXqmWfvOp}XnSM7kACrqlRjbn?^k7u7lt z-}3A?CdgCS0$p>>y^{-;vNbr^#GPG6F>T{XUWD}ird#cCvdxSQ7Au=lKlmkl86A+7 z1GjMS(y2=Aox8v7t*sakfl5G1-Pc=(O1eW>gxqcjyCY`EU@#aE=g0yax#w^PQkp1$ zJ$^<@Y-;BQ+$;`hGIRxXyI=tgl!uFi*#9z#;EpWOP)6O3ca8^BZ~zzfyH`4Rt_DvM z2td7h+oukpK?%M<6beTa|M1<4VY40roj$(zpm%KR@LcgH?s1#tr+4k6DVmRAO^ME= zC%3tQoD)H#{}*j-)9>vYc7J9>N7|)k>t#lZR%d(DD2@*zYJ1MR?@88(!5qW_IRT zY2g_`wPizx7m}Yyoa3S{DJj`!BXygllm9@eKxrd+8@?wN5`=|^6JCz{rSq~Cbj_;F zeyh^a#$Q476q@s^wh=82cti!Dt%ZD!0A+i{MPT_EIfRLhu}j&jE$wL}Weq58G7A^HEHzi-YaqlQ?=b>3rBuKl-`~ zRxJ&JVV2}L=g#&$euFE|2ZrmCt_+$Dn52Y$r%Gwk4s#+9QGh!W&;dBG1_Clc zZ(=%t-m0r)H2{u$6_h#(+`#~vA#eZYuaBj8TzGKIP&~|y;B6hn6d*5M~le6 z>D+*$hzz-9$pW>^0#77Es{hg+c-RKumqiqh#B3|f_0ZAYd7MEpS}n1*+$>p~ou3Oj zIO_pJnhnkPMbzU?^QGG9KLf$ED;f&h1iHx2llEcWwXfSr14zN&MAl4j_S z%CzL>%V8*f+19M6L9|gww4&6BO60j>a$ZoiOHZo zli&rnXYPp#qO&Qf6Zh5%6X1VI6V_d8as9+>_#U9c=De5aGh6B+DNrZ!X&;x-#CY;N z;kxwaBZI;ams=nCyy*WPixx=l-qC45kMqVRMIMRuwS&iW+JetL+Kws}hi9Im7Q`J? zKwT_y?B6n>U02dZTL(UkXuF~bH?v7%A`h-SD{-SA^yKo~nLLVq_%}6ithcz-31@+t z?F{UfR%JqIJ~Q|~Ac7CL)VqP4+1FdQ5Z~uV-;HkH-O*X*p+9t?fR2LDFH5;3*1_X) z0=a#{J+a;0z;iVTiH?Ej!+e9jcaglCh$^U$;=m&{1`Fung#o@1l;>z<^jUV?fMC|k zVkcxTuMFyi-+3>wvkk~vRf!oL$Sa%~`Zz?v8LdKox zw2V$3TO0pN+N}e{s7_W__(~Ul;`Mr7W}Btjn>CqRd*xt8Z_}QmsLs?g;?113=ql{e zN0FG65z#*J8xS=OzN{>(+vY4`*n`?6d45D6=c13s~^Lc+pK!0dwu zgBMbQX4<>Gz~jsAnGBtVcOK^f52M2jFev*0i#vPpJixw5t^*CMK75-NU!bL>2Us4S z;l>(b1|J1_&Y{Y+jxE>#2l1A~b`$n&)!FASu6z@#?KL%{8>@W1^-`Pt^ZLVvmrDDy zo1wKA?b;awFbL33q9nCc&YT+W_@w!M3^pK3r{B#?{mhJ|xrzEuiH-BXPDvl|=t{q5a>@N7ld?jXu~N#Na_htzRJC;NyG1 z(z0WBG=F_Uz42YINvrm5kF52>5AlILbjgDm-nBd@|J8ELu|9W)|7`}k}lnswrwx&mMv0Y{Pg@_*saXkwPYQsd>) zA>#JvrHh#UQXE+mg`y~O?RgdaXCJ8Hw#DXmv3b+rp#FQk$bh)0fT-5&C7tr#KL83y z^2x&4o5MihuhaQ*Vn7*N0Sp=gE(0@DXlL#Om~SbFR|5W6!l(6NhS}7=fUVDJ(}BzS zBI|!EAgplMldW+r@In9{@y6Y8b{I)b@V|PmGkwyxNTNbUh+kj`qi!kAD|Wh{5K8d z-&*NI1&+>K-@)ZLxF0YC`VROujimD6ULWSNaqY(MslQbL$)G>qR6=RU#|;orl~>&u z|0@Bt5j9&}(BpT^Kt8~0mI0{kD|0B#MN;oPea{bhY*ML&?f&yfBW3h|QX4kt&7G$^ zyY)#iUz2pW-4Q6K6KN3YV1o|uDIOoYXUp*GI1HeHb)IDZwWZ+2<6D|_!uKDpPxVQ+ zfvPD;pOf!PdMy4RnnRO`ey4vzRL=&fa)dM8wXi`k^RIe=qcht3Z*c<9sED@D@}udc zT3}MNLlqq;FRzZ~Q)t(Voc1`chXaZ*JT^%?cNT|TX~*ovfZ;I?;CSu^vcd3q^H*5m z%jZWNu0_DR&&eBn1@^`SGblVz!pCO_kY38`OBf%vjMmJsp@t3m#|G4~`{)nrkNpZAJ&w zk)d)uQ!=lc=Y2*KSrL0DS5-a1kk~h*f(tmXmKtEM2SiS{AJ-OE){0fCMBxNHIC^?| z35t$P4y0wR0H>V%TE=?39t7Es`awr4CE(#hx7Hjw7W!aM{y5rq0#^WP=qV~ zanp;d^Okyz{OI-V7p}NR1L}9AuX6v@j<{qM_GA_2@k@8*?eIug@2uB;7#!}7m-mTR z#qqj6t5zN8I9h7Lj{+6nl+k7-#Do<*hma;}Z(k*aEI~}@>gt;55L#niy$_ljLk%U4 z(fnU!y#-j5YtTL}2nwjgQql`79ZEN>^n!$R3rKfMNh#8`goKoIhae?_0!m4PG)O2& zxAgy6)boA6?|%*#*Etu<%lpnd&&)kD_soc0uGa_j2;>ZwDRIO!Bx`Mn2IDa7W^3Dy zj$XU66tMvtDA(S<-brBm`F56p(0uyS-pr?m_K>v4o zdHT)DvqCJpP6#N;qsk2=s@vn3uH?&xk6)4#K{}eGoP15AIQPTc2dZQNl^H=kmt1$T zo#_L53GZ4ztUh6l-sM0H?o=AnI3g;-3M@o95k0~bS~mCW#_{e+megCDRawnZiPAJ0 ztuT4Dw$j=H0t!4)i#L)_o67mOPVu~4(At_t-0I4!B(9vl8!8@W+u#GGpc@vmKFP{# zx)#FYc~1~{U^_aety!`C^$EhMI!DwR+DecO<+U1U&FAhl*uYYM6N8^>h4b)^kd zka}<8--r71vtn6h2;-sLnktFf`mFO4G%8`8IW!f0iQOOQFL#d14@g5pe)d}V$w%g9 zt(mO+w&MFuI@+S&4nle8f_BFc_!o{eAvteLIyR{dg$4KP_)!YEo@5!uvpV=7>_wcY z5n(%?MV3@`8}uR|@`I!Ip<;RA1qhL*uVI9Q?T)QI3s_YwOMv{*pNU69lSxGS6tler zExreYe@?xs_7iXP7{ol`C9p5}7DeFq$w0*B#?p?y3N|=}Q-~16x-55_Yg++Z6CNLY zWxgD+>}GJ-b+Y&ibit;+(0wJJQ~UWoD!5V4tGcpPaGDl7rZZ%qImgbspYGVrKE>Tg 
z)A&fQc|dXT<97qTotP7Kj6g~FWa9OqJH{6KcxlqGqDOFVFR!LN86>llQCApItmvBm zbst55PhcpR{aji88USw&)7NC9(0`4-^^IQo4HhO6ajk-xN`+us%)){L@PuROy#|fY z*5mKVCRn&yWk}`MS(?ZvSZ_7J@l?kUzdUnA^*kQi?5_;|21DwoZiKIev<*}ngC9dW zNzfHvp|5?rK0Dj0u9EVO8*=2a@I#qBfl;N;fTzLFESjQ!_cl7&*O-B(xN*O2^j$#3 zFF7H8R0yY@fj5wnU9U~ z{2eVs+GcU~?e^C#B~p)+;LbguT>ruey(SW!cfI^`V`{&^l1O4&geiyoudyR!E_uD#w?IS@O>rYdn~srzsYDBrE^DHsly$EvirQ|aGAek6~PYqKhq%skqkertW4 zw&gg0T8w*+uAC+X)X`@$1M;q5-60L0?AylW!Ca3CyPGcaOT_CU9%M>1EFt-TBWtUo zbrAC~OnwLl^x`Zm7|nO7T6?bF0(%t%puu3E#67R{ z!S@N+b}WzrfyX79^utDKZ@fD;NwPCGd9iIhJ=SMPWW4iw%bK9?3Aw^4saAUqUGQ#x z$sF|_{V*M}hH+aO>xZLgiH}6U4_W0sJi)>z?lpxuk0o9zH|^>WT)m7ifC(cX+%qb@ z2(=nSzst|=`rMSAvDR^px?k*w5)5udHlTYY4P!a^36dOSOSQhUZ}q=K>xd`_NBr;u zyWIKMQMu7XY)Wxo9fBBMw)qdwPFWWo^(1k;&C>_Q7*m7|1S!zKhlE(SHT;hZn{JF0 zqPeZU`CyR%(?YgCgj5wd&OOvl{~`TUSK`Lg(&!%`y3C9S%FD$EWbqf%A7XMWD;5>x z1#Prh0?CQb9`$~6ZX(`#A@{vEPgOoL-vty&{0`Kxn2kD@3awA%w0_DZNyDrs5(lu@ zr+sm~_LgKLz+r(tc9?l=0e~@A(_GrS8aZDa* zfJ*Id4(T5saZS z_-^Q+0WW+?m>?Dcfe;PKx6>2ERf{;EcWx3GpQHYaaY}0nPqC-lS-lPMy@j|?$s}a` zdqOst^B}vHJcO{zsO#Wcj1BWzruaof^Y2Q^en2}r!|%yeCv{wC$x>bFUwa;I9-i&n zMdz`z;IZd`5hV4=dG!|+!;eQG9xc=@dV%PnCo8KtksAYYkb!MLP${1@+#8F7jl=_4 zveq+3x^jI6ZL?1vtPIpsHu6_wO?gzCTuFchB?0PH&5~<7^4hh-Jp%cBXe2FqYYwK~s#&3}wW z>pp1XGSDsd!LKy;GiV@)R{XZLpdf$ujp+iuYAJp=X;;jf;N5K1KvCw{Z#TVMeA=c& z3Kq~+42BtCFgiwlv^cy4CkgHwPq*Y2J#=6W(_!!nQE_JdYAn_y)^L!_Btv(9abY4r zfY9^X(1K~*F#1%seV4cJk?f=PePrK_;DyfBlI?p)(iwdXxV1#_us0dSyQ_Q%kj{r2 zU^koDE4^>!@cxX$2Ia_(T(Im1Sam6(MueFb;|zg(3XKVum1_KvcJy7aV_Z`nxp-)w zWv8q8&}NdTv&^5~M-FO&;sW4d!myNASb9&S6c74;$by@ zA2uI-{hf4yy}fuGWNWq_#Nq>AB6r>8QD+2d^1{SFYwo4nId1hHCof^`M zJl$*uywyZ{K|vjsdeFhh5;T+(BW}~ut%EcKre+-T$;)7AnIg%d&6(D85qKOO{P_)d z+sk52C_t70Kk)^94(-A{-F}U*4>{Z^R0v_*j^ArOQ)xGD+=wrh^R;}am7tWYVdX#a}u zfRm*FFJfIj$5-v$*$HWHUz=Xm$3G}DEt0fSui&-OWd3FSjr$ZJ1vomVB1jzk8YnZo z)%p3SLZU{U$L#DO)mUL+ArTj1%7lZj#$iTSMzvV0EcF?ixwS6cy!9SEJNxa-XNM^W z3$VUM4)8Sx-g~(waN19#?~aw}S6L0xF*tNB#z;7LP%g<`0K@z7d23g^%{x|24K9Rq z-Wxnewxg}N7zV{8W3Zb!+&n1H?c&TU^x$zIdQE6ud-z;D+=D;|jE@NSgzZF!lin-G zk~8&M`bEJmYJc8YZX0@lBIdQO4i`JsI6nUMEH_ln#f#Nq0yE>eyX_^Cex{)n+C-Ue{1Jktwnh#7a&W`K74~mg; z)YAUDT})PUADa1;j8PHeO~Ub7TT*_z@c=OA^?PgJbp?AsHBoy(PY|o1M^XCe<3^44 zPDZugcm!$B70R{+{;pJyno6QPgElQjTCn)>+$7}TPyG*+I&Wpi80R}k-)d=T_5C=B zBlZ^kHiKP%dgc@!YUUYdfE^XU>}JUs0u;NSyRnFL<(P47Z8tj`g;_%<+R?|2UpEi633w5gTZIbT2VIDF+r*hm_ z+9VY}l|yNxH~oCq+Lt`VV*RFiLF09uEM=-MF@2zdqK=X0G6 zb|IGB+P8`?J0U8k03r@HxLTT*AFd9|IVyR``%HVCJ7L`?x#G#nVzdjdbP|m2%1_5d z)~ML7)_oF6YU9gQ%EdP0UeZe{silc85hPgoZ~ zSr*84&B2ql8^&FbEJ@7y+HV2V4e30Gw6wH9AMqyN*-Zp>c6KI&Prb8GqhqbL8vG=5 z$CN{hz^UNagqg1ES+=t`STCReDv_a~SYS8m4X~<_87W6U;~`GDG6D)Jm^M6jX_S>TN0@f#V36_t965#LcoF2@7x{D zqEGatejR$kO)R}3j65=)&&GpqXtoc&JznH{zDgMR?X}R`O;a(7rmkXCy(9qRH5@7~t|!|HUH#NHh)Jl6gBWQzUQ6Rq`@`4pMRJkEMb*5uktk2#(TC zynshJBQ&fNDIV+NadfO*J{-r#$9Y%^bnLo>{fr`~9~%#pT9KhBpEN?UM6B*r=>B?LyP<13h&(t#o&=Hth6M2c?)R1ZD| zwcq0*N&w^bBMLGd0>04T$ycc&KY>LExXu*EU>(?Zz*-zTf>yf2=IJnCH&W`u5}s_$ zHg}i!G8UP%y<~X<@_`!$xpH=lD1;vi_khIKZ>HA9y?+^8R5mD0@5Q0f$;Bb+j*~2I zXgX4it`v6Z6uv|=nEV<0F^p`FJ9=JARq({ zAg;rIS^{#N(g2X!RRma~BytVglM=uYzKgsc<;thJ^hRNPJB4P5FQx=Xo^!s}4I0qF zMjPdHN_`qd(5nPRQ$z`cu6^r|I?2N=L97ev){kb$BdIs)XB<;D-nA=1xU7Lh1UG%-eu}1tJyEM;?HOPG;z_RTCb6$4#hgB?Bk`!DmnMklt zkG53ewWS%3KY*dWI<+>MFI#8*PVHrCkhc2ePt*|$Wg2CA#KE^fVkYl0@2zTg{7=z$ zcE=?4Q%S8yb1jRF*V~ig0|>E0tzeRI9-#1?KawHfXw-r84tTWgNjk}UGb2LIPCCS> zr?0;sxa~CGn=&HxRuBhEo;bQ#!`bMjW&aJ2tsAMlXc_*qUI*1zuKKMY#(cNQL? 
(base85 binary patch data for the preceding figure omitted)
literal 0
HcmV?d00001

diff --git a/egs/librispeech/WSASR/figures/sub.png b/egs/librispeech/WSASR/figures/sub.png
new file mode 100644
index 0000000000000000000000000000000000000000..5674e9febf3faabfce52a1491adcaddeb066f4fa
GIT binary patch
literal 15900
(base85 binary patch data for figures/sub.png omitted)
zyslxX28jgio(Mz$GieIzu z;dcKQT2_swZLJyb_8;%cdD_is%x z1_qmsh&ann>zYXidK(`LOtm;9%rUN6!x?$ngMZZ_6@F)%IXpN^yw0pvfpF)H79vis>`CdY8?M08IehAX(+R zkJIY1KQ>HG`Zags@|oc%HL?`KkYNx4^P^J61(DNF-f9!+z>obbcn33s=#j%=?QSJx!_`Ea3a3{f%fovoDILL<$sT#Pm?88+HN&KaZrRZclXnVF&(qZ8EP!!v9_tgB_}G&W$2!%IXfH^~*tk?W1lYL&Y>PlpQ=hp$^jR_qJbTFz!Jz zDeOnetGVaI!@sKIeug{_wgs-7gdyCkr>dz?7_)dO@s5Pl#UCwS5F!1#L9X+!WGh5#hQy)vd<5JTa6p* zLygg(XS0m_=6cLuz;OZ@@wk%hA zvXPNoy|!2JLizX+)=BMsW=x~Y16{regCTf+S98;m_YTipQP->9_Nk83l4dOD(!gI( z!qKK;*2BlDu=UjV7)J{gIFRU_aATfNeNX4&zDahm^hY3`%Uf!j4y%3G^1QucBKk)e zOU|>TuAk&ckNk4*j$1g2WE(v9iK9Y9P}(UAe-x-eisI)e9;oit2Ba|UE>rU6!bFhfljF*tDx+hg%lZ?r?) zX)}6l5(=#sxM6L^GGJT$#`gFGXTbAw-&^b8&s0oIZ-sfYvOav6w6SIE-6+!}e7RqK zRz9gv-26e`F#TY?1~)7yne-DS!Ske;n29q~+d#@+l9H>|c96Uo%zc3}^TyWo1aF3- z57n@}r^7XP&@ZjUK{=bRQ79~nEJ;C!_J*tH>LZum8raZKX`0wf`N^7TP{*0-re77t zh&OQ0sHR1vR0+WFC@Y{zu2&7r$GjL6CrQ`P{9#qb9HuXLYBiT(sK#pZJ%>$1(hd$r zwNsqlh5mTf@qFT`jScSkr)E9Aa&|IyA|TO$nUYTCBIc5~I>$7^QViwJD$=B^hv!(7 zf^^o!&!E$*Nrmgs)u^a(otw8tRC}!y4wq^822naHT7GQA3X zScavhrH%BCP@OH<%t|CzVkDHaiBE72fk1s1fCi60t7h$UWkrDZtc>R%Avdis^zuUC zBk92HCkjSpfC5V8+(9_MNs#N0y&$F6{I6wnuDG{l<8pDXpK#ZX7PO$1bkq zV3aAXM7)S8(zk02DLgYFYVFzXZovr3&op={Mi?4zQzDA((#7)0iY#WFPuaB8@wc++ zVqcheG@61$a~UjTv~Namn{alen^L>9`86cx{%8W3N`x*{mrhW3W-K9+-JF zn}Q&?RQ4J3)B)6oYb(4IilYa~faI9M!im?ML;@}o2_;Adh0RhaNlDnm&MS2_wWJGc zI_1x_*_?Cbu#yu9K1{yL0qVke2{kDdpcI8qdJ+Ct&@G|oNVobiI$IA-jLY_1V2 z;0q%I=)tK%*F(E>XfMV~aNN5@M%V!ob>?BFv;rI^YI>leZ&ypDf k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path(f"{lm_dir}/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"{lm_dir}/{lm}.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info(f"Loading {lm}.fst.txt") + with open(f"{lm_dir}/{lm}.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), f"{lm_dir}/{lm}.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. 
If you want to change it, please + # also change other places in icefall that are using + # it. + HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lm_dir = Path(args.lm_dir) + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lm_dir, lang_dir, args.lm) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/WSASR/local/compute_fbank_librispeech.py b/egs/librispeech/WSASR/local/compute_fbank_librispeech.py new file mode 100755 index 000000000..a387d54c9 --- /dev/null +++ b/egs/librispeech/WSASR/local/compute_fbank_librispeech.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. 
If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_librispeech( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "librispeech" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + + if "train" in partition: + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_librispeech( + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/librispeech/WSASR/local/compute_ssl_librispeech.py b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py new file mode 100755 index 000000000..f405c468c --- /dev/null +++ b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
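+# Note: this script extracts self-supervised (wav2vec2) representations with
+# lhotse's S3PRLSSL extractor on GPU and stores them under data/ssl; it is the
+# SSL counterpart of local/compute_fbank_librispeech.py.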
+ + +""" +This file computes fbank features of the LibriSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_ssl_librispeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/ssl") + num_jobs = 1 + + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + ) + prefix = "librispeech" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + storage_type=NumpyFilesWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_ssl_librispeech() diff --git a/egs/librispeech/WSASR/local/filter_cuts.py b/egs/librispeech/WSASR/local/filter_cuts.py new file mode 100644 index 000000000..fbcc9e24a --- /dev/null +++ b/egs/librispeech/WSASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. 
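+(As written below, the filter drops cuts shorter than 1 second or longer than
+20 seconds, and cuts whose BPE token count exceeds the number of frames left
+after subsampling.)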
+ +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. + ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." 
+ ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/WSASR/local/get_words_from_lexicon.py b/egs/librispeech/WSASR/local/get_words_from_lexicon.py new file mode 100755 index 000000000..0cc740b36 --- /dev/null +++ b/egs/librispeech/WSASR/local/get_words_from_lexicon.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +import argparse +from pathlib import Path + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--otc-token", + type=str, + help="OTC token to be added to words.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + otc_token = args.otc_token + + lexicon = read_lexicon(lang_dir / "lexicon.txt") + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + words = [""] + sorted_ans + [otc_token] + ["#0", "", ""] + + words_file = lang_dir / "words.txt" + with open(words_file, "w") as wf: + for i, word in enumerate(words): + wf.write(f"{word} {i}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/make_error_cutset.py b/egs/librispeech/WSASR/local/make_error_cutset.py new file mode 100755 index 000000000..8463a380e --- /dev/null +++ b/egs/librispeech/WSASR/local/make_error_cutset.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +# Copyright 2023 Johns Hopkins University (author: Dongji Gao) + +import argparse +import random +from pathlib import Path +from typing import List + +from lhotse import CutSet, load_manifest +from lhotse.cut.base import Cut + +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--input-cutset", + type=str, + help="Supervision manifest that contains verbatim transcript", + ) + + parser.add_argument( + "--words-file", + type=str, + help="words.txt file", + ) + + parser.add_argument( + "--otc-token", + type=str, + help="OTC token in words.txt", + ) + + parser.add_argument( + "--sub-error-rate", + type=float, + default=0.0, + help="Substitution error rate", + ) + + parser.add_argument( + "--ins-error-rate", + type=float, + default=0.0, + help="Insertion error rate", + ) + + parser.add_argument( + "--del-error-rate", + type=float, + default=0.0, + help="Deletion error rate", + ) + + parser.add_argument( + "--output-cutset", + type=str, + default="", + help="Supervision manifest that contains modified non-verbatim transcript", + ) + + 
parser.add_argument("--verbose", type=str2bool, help="show details of errors") + return parser.parse_args() + + +def check_args(args): + total_error_rate = args.sub_error_rate + args.ins_error_rate + args.del_error_rate + assert args.sub_error_rate >= 0 and args.sub_error_rate <= 1.0 + assert args.ins_error_rate >= 0 and args.sub_error_rate <= 1.0 + assert args.del_error_rate >= 0 and args.sub_error_rate <= 1.0 + assert total_error_rate <= 1.0 + + +def get_word_list(token_path: str) -> List: + word_list = [] + with open(Path(token_path), "r") as tp: + for line in tp.readlines(): + token = line.split()[0] + assert token not in word_list + word_list.append(token) + return word_list + + +def modify_cut_text( + cut: Cut, + words_list: List, + non_words: List, + sub_ratio: float = 0.0, + ins_ratio: float = 0.0, + del_ratio: float = 0.0, +): + text = cut.supervisions[0].text + text_list = text.split() + + # We save the modified information of the original verbatim text for debugging + marked_verbatim_text_list = [] + modified_text_list = [] + + del_index_set = set() + sub_index_set = set() + ins_index_set = set() + + # We follow the order: deletion -> substitution -> insertion + for token in text_list: + marked_token = token + modified_token = token + + prob = random.random() + + if prob <= del_ratio: + marked_token = f"-{token}-" + modified_token = "" + elif prob <= del_ratio + sub_ratio + ins_ratio: + if prob <= del_ratio + sub_ratio: + marked_token = f"[{token}]" + else: + marked_verbatim_text_list.append(marked_token) + modified_text_list.append(modified_token) + marked_token = "[]" + + # get new_token + while ( + modified_token == token + or modified_token in non_words + or modified_token.startswith("#") + ): + modified_token = random.choice(words_list) + + marked_verbatim_text_list.append(marked_token) + modified_text_list.append(modified_token) + + marked_text = " ".join(marked_verbatim_text_list) + modified_text = " ".join(modified_text_list) + + if not hasattr(cut.supervisions[0], "verbatim_text"): + cut.supervisions[0].verbatim_text = marked_text + cut.supervisions[0].text = modified_text + + return cut + + +def main(): + args = get_args() + check_args(args) + + otc_token = args.otc_token + non_words = set(("sil", "", "")) + non_words.add(otc_token) + + words_list = get_word_list(args.words_file) + cutset = load_manifest(Path(args.input_cutset)) + + cuts = [] + + for cut in cutset: + modified_cut = modify_cut_text( + cut=cut, + words_list=words_list, + non_words=non_words, + sub_ratio=args.sub_error_rate, + ins_ratio=args.ins_error_rate, + del_ratio=args.del_error_rate, + ) + cuts.append(modified_cut) + + output_cutset = CutSet.from_cuts(cuts) + output_cutset.to_file(args.output_cutset) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/prepare_lang.py b/egs/librispeech/WSASR/local/prepare_lang.py new file mode 100755 index 000000000..d913756a1 --- /dev/null +++ b/egs/librispeech/WSASR/local/prepare_lang.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool + +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. 
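+    # For example (toy lexicon, for illustration only): with entries
+    #   ("report", "r e p o r t"), ("report", "r e p o r t") and ("rep", "r e p"),
+    # the sequence "r e p o r t" has count 2 and "r e p" is one of its prefixes,
+    # so step (3) below rewrites them as "r e p o r t #1", "r e p o r t #2"
+    # and "r e p #1".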
+ count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. 
+ need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + sil_token = token2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [token2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" + sil_token = "SIL" + sil_prob = 0.5 + + lexicon = read_lexicon(lexicon_filename) + tokens = get_tokens(lexicon) + words = get_words(lexicon) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + assert "" not in tokens + tokens = [""] + tokens + + assert "" not in words + assert "#0" not in words + assert "" not in words + assert "" not in words + + words = [""] + words + ["#0", "", ""] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = 
k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py new file mode 100755 index 000000000..415bdff6f --- /dev/null +++ b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. 
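+
+    Example (toy entry, for illustration only): ("hello", ["▁he", "llo"])
+    produces two arcs, loop_state -> s1 with labels ▁he:hello and
+    s1 -> loop_state with labels llo:<eps>; the word id is emitted on the
+    arc for the first piece and epsilon on the rest.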
+ """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_otc_lexicon( + model_file: str, + words: List[str], + oov: str, + otc_token: str, +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + otc_token: + The OTC token in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + # Add OTC token to the last. + lexicon.append((otc_token, [f"▁{otc_token}"])) + otc_token_index = len(token2id) + token2id[f"▁{otc_token}"] = otc_token_index + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--otc-token", + type=str, + default="", + help="The OTC token in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. 
+ """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + otc_token = args.otc_token + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!SIL", + "", + args.oov, + otc_token, + "#0", + "", + "", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_otc_lexicon( + model_file, words, args.oov, otc_token + ) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/train_bpe_model.py b/egs/librispeech/WSASR/local/train_bpe_model.py new file mode 100755 index 000000000..43142aee4 --- /dev/null +++ b/egs/librispeech/WSASR/local/train_bpe_model.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. 
+ """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/validate_bpe_lexicon.py b/egs/librispeech/WSASR/local/validate_bpe_lexicon.py new file mode 100755 index 000000000..16a489c11 --- /dev/null +++ b/egs/librispeech/WSASR/local/validate_bpe_lexicon.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks that there are no OOV tokens in the BPE-based lexicon. 
+ +Usage example: + + python3 ./local/validate_bpe_lexicon.py \ + --lexicon /path/to/lexicon.txt \ + --bpe-model /path/to/bpe.model +""" + +import argparse +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm + +from icefall.lexicon import read_lexicon + +# Map word to word pieces +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--lexicon", + required=True, + type=Path, + help="Path to lexicon.txt", + ) + + parser.add_argument( + "--bpe-model", + required=True, + type=Path, + help="Path to bpe.model", + ) + + parser.add_argument( + "--otc-token", + required=True, + type=str, + help="OTC token", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + assert args.lexicon.is_file(), args.lexicon + assert args.bpe_model.is_file(), args.bpe_model + + lexicon = read_lexicon(args.lexicon) + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size())))) + word_pieces.add(f"▁{args.otc_token}") + for word, pieces in lexicon: + for p in pieces: + if p not in word_pieces: + raise ValueError(f"The word {word} contains an OOV token {p}") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/validate_manifest.py b/egs/librispeech/WSASR/local/validate_manifest.py new file mode 100755 index 000000000..f620b91ea --- /dev/null +++ b/egs/librispeech/WSASR/local/validate_manifest.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within cut time bounds + +We will add more checks later if needed. 
+
+Usage example:
+
+    python3 ./local/validate_manifest.py \
+            ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.cut import Cut
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "manifest",
+        type=Path,
+        help="Path to the manifest file",
+    )
+
+    return parser.parse_args()
+
+
+def validate_one_supervision_per_cut(c: Cut):
+    if len(c.supervisions) != 1:
+        raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
+
+
+def validate_supervision_and_cut_time_bounds(c: Cut):
+    s = c.supervisions[0]
+    if s.start < c.start:
+        raise ValueError(
+            f"{c.id}: Supervision start time {s.start} is less "
+            f"than cut start time {c.start}"
+        )
+
+    if s.end > c.end:
+        raise ValueError(
+            f"{c.id}: Supervision end time {s.end} is larger "
+            f"than cut end time {c.end}"
+        )
+
+
+def main():
+    args = get_args()
+
+    manifest = args.manifest
+    logging.info(f"Validating {manifest}")
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet)
+
+    for c in cut_set:
+        validate_one_supervision_per_cut(c)
+        validate_supervision_and_cut_time_bounds(c)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/librispeech/WSASR/prepare.sh b/egs/librispeech/WSASR/prepare.sh
new file mode 100755
index 000000000..f6a922fde
--- /dev/null
+++ b/egs/librispeech/WSASR/prepare.sh
@@ -0,0 +1,233 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/LibriSpeech
+#      You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+#      You can download them from https://www.openslr.org/12
+#
+#  - $dl_dir/lm
+#      This directory contains the following files downloaded from
+#      http://www.openslr.org/resources/11
+#
+#        - 3-gram.pruned.1e-7.arpa.gz
+#        - 3-gram.pruned.1e-7.arpa
+#        - 4-gram.arpa.gz
+#        - 4-gram.arpa
+#        - librispeech-vocab.txt
+#        - librispeech-lexicon.txt
+#        - librispeech-lm-norm.txt.gz
+#
+otc_token="<star>"
+feature_type="ssl"  # ssl or fbank
+
+dl_dir=$PWD/download
+manifests_dir="data/manifests"
+feature_dir="data/${feature_type}"
+lang_dir="data/lang"
+lm_dir="data/lm"
+
+perturb_speed=false
+
+. ./cmd.sh
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  200
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: ${dl_dir}"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download LM"
+  mkdir -p ${dl_dir}/lm
+  if [ ! -e ${dl_dir}/lm/.done ]; then
+    ./local/download_lm.py --out-dir=${dl_dir}/lm
+    touch ${dl_dir}/lm/.done
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/LibriSpeech,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
+  #
+  if [ ! -d $dl_dir/LibriSpeech/train-clean-100 ]; then
+    lhotse download librispeech --full ${dl_dir}
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.librispeech.done ]; then
+    lhotse prepare librispeech -j ${nj} \
+      -p dev-clean \
+      -p dev-other \
+      -p test-clean \
+      -p test-other \
+      -p train-clean-100 "${dl_dir}/LibriSpeech" "${manifests_dir}"
+    touch data/manifests/.librispeech.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Compute ${feature_type} feature for librispeech (train-clean-100)"
+  mkdir -p "${feature_dir}"
+  if [ ! -e "${feature_dir}/.librispeech.done" ]; then
+    if [ "${feature_type}" = ssl ]; then
+      ./local/compute_ssl_librispeech.py
+    elif [ "${feature_type}" = fbank ]; then
+      ./local/compute_fbank_librispeech.py --perturb-speed ${perturb_speed}
+    else
+      log "Error: not supported --feature-type '${feature_type}'"
+      exit 2
+    fi
+
+    touch "${feature_dir}/.librispeech.done"
+  fi
+
+  if [ ! -e "${feature_dir}/.librispeech-validated.done" ]; then
+    log "Validating ${feature_dir} for LibriSpeech"
+    parts=(
+      train-clean-100
+      test-clean
+      test-other
+      dev-clean
+      dev-other
+    )
+    for part in ${parts[@]}; do
+      python3 ./local/validate_manifest.py \
+        "${feature_dir}/librispeech_cuts_${part}.jsonl.gz"
+    done
+    touch "${feature_dir}/.librispeech-validated.done"
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare words.txt"
+  mkdir -p ${lang_dir}
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > ${lang_dir}/lexicon.txt
+
+  local/get_words_from_lexicon.py \
+    --lang-dir ${lang_dir} \
+    --otc-token ${otc_token}
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    bpe_lang_dir="data/lang_bpe_${vocab_size}"
+    mkdir -p "${bpe_lang_dir}"
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp "${lang_dir}/words.txt" "${bpe_lang_dir}"
+
+    if [ ! -f "${bpe_lang_dir}/transcript_words.txt" ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+      )
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > "${bpe_lang_dir}/transcript_words.txt"
+    fi
+
+    if [ ! -f ${bpe_lang_dir}/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir ${bpe_lang_dir} \
+        --vocab-size ${vocab_size} \
+        --transcript ${bpe_lang_dir}/transcript_words.txt
+    fi
+
+    if [ ! -f ${bpe_lang_dir}/L_disambig.pt ]; then
+      ./local/prepare_otc_lang_bpe.py \
+        --lang-dir "${bpe_lang_dir}" \
+        --otc-token "${otc_token}"
+
+      log "Validating ${bpe_lang_dir}/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon ${bpe_lang_dir}/lexicon.txt \
+        --bpe-model ${bpe_lang_dir}/bpe.model \
+        --otc-token "${otc_token}"
+    fi
+  done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p "${lm_dir}"
+  if [ ! -f ${lm_dir}/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="${lang_dir}/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      ${dl_dir}/lm/3-gram.pruned.1e-7.arpa > ${lm_dir}/G_3_gram.fst.txt
+  fi
+
+  if [ ! -f ${lm_dir}/G_4_gram.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="${lang_dir}/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      ${dl_dir}/lm/4-gram.arpa > ${lm_dir}/G_4_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Compile HLG"
+  # Note: if ./local/compile_hlg.py throws OOM,
+  # please switch to the following command
+  #
+  #   ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    bpe_lang_dir="data/lang_bpe_${vocab_size}"
+    echo "LM DIR: ${lm_dir}"
+    ./local/compile_hlg.py \
+      --lm-dir "${lm_dir}" \
+      --lang-dir "${bpe_lang_dir}"
+  done
+fi
diff --git a/icefall/otc_graph_compiler.py b/icefall/otc_graph_compiler.py
new file mode 100644
index 000000000..bfd679452
--- /dev/null
+++ b/icefall/otc_graph_compiler.py
@@ -0,0 +1,246 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+from typing import List, Union
+
+import k2
+import sentencepiece as spm
+import torch
+
+from icefall.utils import str2bool
+
+
+class OtcTrainingGraphCompiler(object):
+    def __init__(
+        self,
+        lang_dir: Path,
+        otc_token: str,
+        device: Union[str, torch.device] = "cpu",
+        sos_token: str = "<sos/eos>",
+        eos_token: str = "<sos/eos>",
+        initial_bypass_weight: float = 0.0,
+        initial_self_loop_weight: float = 0.0,
+        bypass_weight_decay: float = 0.0,
+        self_loop_weight_decay: float = 0.0,
+    ) -> None:
+        """
+        Args:
+          lang_dir:
+            This directory is expected to contain the following files:
+
+                - bpe.model
+                - words.txt
+          otc_token:
+            The special token in OTC that represents all non-blank tokens.
+          device:
+            It indicates CPU or CUDA.
+          sos_token:
+            The word piece that represents sos.
+          eos_token:
+            The word piece that represents eos.
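+
+        Example (an illustrative sketch only; the lang dir path, the OTC
+        token value and the device are assumptions, not defaults of this
+        class):
+
+            graph_compiler = OtcTrainingGraphCompiler(
+                lang_dir=Path("data/lang_bpe_200"),
+                otc_token="<star>",
+                device="cuda:0",
+            )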
+ """ + lang_dir = Path(lang_dir) + bpe_model_file = lang_dir / "bpe.model" + sp = spm.SentencePieceProcessor() + sp.load(str(bpe_model_file)) + self.sp = sp + self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + + self.otc_token = otc_token + assert self.otc_token in self.token_table + + self.device = device + + self.sos_id = self.sp.piece_to_id(sos_token) + self.eos_id = self.sp.piece_to_id(eos_token) + + assert self.sos_id != self.sp.unk_id() + assert self.eos_id != self.sp.unk_id() + + max_token_id = self.get_max_token_id() + ctc_topo = k2.ctc_topo(max_token_id, modified=False) + self.ctc_topo = ctc_topo.to(self.device) + + self.initial_bypass_weight = initial_bypass_weight + self.initial_self_loop_weight = initial_self_loop_weight + self.bypass_weight_decay = bypass_weight_decay + self.self_loop_weight_decay = self_loop_weight_decay + + def get_max_token_id(self): + max_token_id = 0 + for symbol in self.token_table.symbols: + if not symbol.startswith("#"): + max_token_id = max(self.token_table[symbol], max_token_id) + assert max_token_id > 0 + + return max_token_id + + def make_arc( + self, + from_state: int, + to_state: int, + symbol: Union[str, int], + weight: float, + ): + return f"{from_state} {to_state} {symbol} {weight}" + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of piece IDs. + """ + return self.sp.encode(texts, out_type=int) + + def compile( + self, + texts: List[str], + allow_bypass_arc: str2bool = True, + allow_self_loop_arc: str2bool = True, + bypass_weight: float = 0.0, + self_loop_weight: float = 0.0, + ) -> k2.Fsa: + """Build a OTC graph from a texts (list of words). + + Args: + texts: + A list of strings. Each string contains a sentence for an utterance. + A sentence consists of spaces separated words. An example `texts` + looks like: + ['hello icefall', 'CTC training with k2'] + allow_bypass_arc: + Whether to add bypass arc to training graph for substitution + and insertion errors (wrong or extra words in the transcript). + allow_self_loop_arc: + Whether to add self-loop arc to training graph for deletion + errors (missing words in the transcript). + bypass_weight: + Weight associated with bypass arc. + self_loop_weight: + Weight associated with self-loop arc. + + Return: + Return an FsaVec, which is the result of composing a + CTC topology with OTC FSAs constructed from the given texts. 
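+
+        Example (an illustrative sketch only; it assumes ``graph_compiler``
+        is an instance of this class, constructed as shown in the class
+        docstring):
+
+            graphs = graph_compiler.compile(
+                ["hello icefall", "CTC training with k2"],
+                allow_bypass_arc=True,
+                allow_self_loop_arc=True,
+            )
+            # graphs is an FsaVec holding one OTC training graph per utterance.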
+ """ + + transcript_fsa = self.convert_transcript_to_fsa( + texts, + self.otc_token, + allow_bypass_arc, + allow_self_loop_arc, + bypass_weight, + self_loop_weight, + ) + transcript_fsa = transcript_fsa.to(self.device) + fsa_with_self_loop = k2.remove_epsilon_and_add_self_loops(transcript_fsa) + fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop) + + graph = k2.compose( + self.ctc_topo, + fsa_with_self_loop, + treat_epsilons_specially=False, + ) + assert graph.requires_grad is False + + return graph + + def convert_transcript_to_fsa( + self, + texts: List[str], + otc_token: str, + allow_bypass_arc: str2bool = True, + allow_self_loop_arc: str2bool = True, + bypass_weight: float = 0.0, + self_loop_weight: float = 0.0, + ): + otc_token_id = self.token_table[otc_token] + + transcript_fsa_list = [] + for text in texts: + text_piece_ids = [] + + for word in text.split(): + piece_ids = self.sp.encode(word, out_type=int) + text_piece_ids.append(piece_ids) + + arcs = [] + start_state = 0 + cur_state = start_state + next_state = 1 + + for piece_ids in text_piece_ids: + bypass_cur_state = cur_state + + if allow_self_loop_arc: + self_loop_arc = self.make_arc( + cur_state, + cur_state, + otc_token_id, + self_loop_weight, + ) + arcs.append(self_loop_arc) + + for piece_id in piece_ids: + arc = self.make_arc(cur_state, next_state, piece_id, 0.0) + arcs.append(arc) + + cur_state = next_state + next_state += 1 + + bypass_next_state = cur_state + if allow_bypass_arc: + bypass_arc = self.make_arc( + bypass_cur_state, + bypass_next_state, + otc_token_id, + bypass_weight, + ) + arcs.append(bypass_arc) + bypass_cur_state = cur_state + + if allow_self_loop_arc: + self_loop_arc = self.make_arc( + cur_state, + cur_state, + otc_token_id, + self_loop_weight, + ) + arcs.append(self_loop_arc) + + # Deal with final state + final_state = next_state + final_arc = self.make_arc(cur_state, final_state, -1, 0.0) + arcs.append(final_arc) + arcs.append(f"{final_state}") + sorted_arcs = sorted(arcs, key=lambda a: int(a.split()[0])) + + transcript_fsa = k2.Fsa.from_str("\n".join(sorted_arcs)) + transcript_fsa = k2.arc_sort(transcript_fsa) + transcript_fsa_list.append(transcript_fsa) + + transcript_fsa_vec = k2.create_fsa_vec(transcript_fsa_list) + + return transcript_fsa_vec diff --git a/icefall/utils.py b/icefall/utils.py index 947d79438..8fda3a4ca 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -263,6 +263,70 @@ def get_texts( return aux_labels.tolist() +def encode_supervisions_otc( + supervisions: dict, + subsampling_factor: int, + token_ids: Optional[List[List[int]]] = None, +) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]: + """ + Encodes Lhotse's ``batch["supervisions"]`` dict into + a pair of torch Tensor, and a list of transcription strings or token indexes + + The supervision tensor has shape ``(batch_size, 3)``. + Its second dimension contains information about sequence index [0], + start frames [1] and num frames [2]. + + The batch items might become re-ordered during this operation -- the + returned tensor and list of strings are guaranteed to be consistent with + each other. 
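+
+    Example (an illustrative sketch only; ``batch`` is assumed to come from a
+    Lhotse dataloader whose cuts may carry a ``verbatim_text`` supervision
+    field):
+
+        segments, texts, ids, verbatim_texts = encode_supervisions_otc(
+            batch["supervisions"], subsampling_factor=4
+        )
+        # segments is an int32 tensor of shape (batch_size, 3), sorted by
+        # descending num_frames; the three lists follow the same ordering.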
+ """ + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + ids = [] + verbatim_texts = [] + sorted_ids = [] + sorted_verbatim_texts = [] + + for cut in supervisions["cut"]: + id = cut.id + if hasattr(cut.supervisions[0], "verbatim_text"): + verbatim_text = cut.supervisions[0].verbatim_text + else: + verbatim_text = "" + ids.append(id) + verbatim_texts.append(verbatim_text) + + for index in indices.tolist(): + sorted_ids.append(ids[index]) + sorted_verbatim_texts.append(verbatim_texts[index]) + + if token_ids is None: + texts = supervisions["text"] + res = [texts[idx] for idx in indices] + else: + res = [token_ids[idx] for idx in indices] + + return supervision_segments, res, sorted_ids, sorted_verbatim_texts + + @dataclass class DecodingResults: # timestamps[i][k] contains the frame number on which tokens[i][k] From 48cc41bd832f4d920e1ae8ff5beb97d1d7fc1f89 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 30 Sep 2023 22:23:22 +0800 Subject: [PATCH 052/113] Fix CI --- requirements-ci.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 6f8739ce0..1eba69764 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -27,4 +27,4 @@ onnx onnxmltools onnxruntime kaldifst -kaldi-hmm-gmm +kaldi-decoder From f14b6734089c1eaa5b02c34c426bc46fbd37c6b0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 1 Oct 2023 13:46:16 +0800 Subject: [PATCH 053/113] Add HLG decoding with OpenFst on CPU for aishell conformer_ctc (#1279) --- .../scripts/run-pre-trained-conformer-ctc.sh | 80 ++++++++++++++++++- .github/workflows/run-yesno-recipe.yml | 2 +- egs/aishell/ASR/conformer_ctc/export.py | 21 +++-- .../jit_pretrained_decode_with_H.py | 1 + .../jit_pretrained_decode_with_HL.py | 1 + .../jit_pretrained_decode_with_HLG.py | 1 + .../ASR/conformer_ctc/test_transformer.py | 0 egs/aishell/ASR/local/prepare_lang_fst.py | 1 + egs/aishell/ASR/prepare.sh | 5 ++ .../jit_pretrained_decode_with_H.py | 16 +++- .../jit_pretrained_decode_with_HL.py | 16 +++- .../jit_pretrained_decode_with_HLG.py | 15 +++- .../ASR/tdnn/jit_pretrained_decode_with_H.py | 2 +- .../ASR/tdnn/jit_pretrained_decode_with_HL.py | 2 +- icefall/ctc/README.md | 6 +- requirements.txt | 2 +- 16 files changed, 146 insertions(+), 25 deletions(-) mode change 100644 => 100755 egs/aishell/ASR/conformer_ctc/export.py create mode 120000 egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py create mode 120000 egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py create mode 120000 egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py mode change 100644 => 100755 egs/aishell/ASR/conformer_ctc/test_transformer.py create mode 120000 egs/aishell/ASR/local/prepare_lang_fst.py diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index a82d85fb2..ea400c628 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -8,7 +8,7 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } -cd egs/librispeech/ASR +pushd egs/librispeech/ASR # 
repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 @@ -112,3 +112,81 @@ log "Decoding with HLG on CPU with OpenFst" $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav + +rm -rf $repo + +popd + +log "Test aishell" + +pushd egs/aishell/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo + +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "data/lm/G_3_gram_char.fst.txt" + +popd + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +log "CTC decoding" + +log "Exporting model with torchscript" + +pushd $repo/exp +ln -s pretrained.pt epoch-99.pt +popd + +./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_char/tokens.txt \ + --jit 1 + +ls -lh $repo/exp + +log "Generating H.fst, HL.fst" + +./local/prepare_lang_fst.py --lang-dir $repo/data/lang_char --ngram-G $repo/data/lm/G_3_gram_char.fst.txt + +ls -lh $repo/data/lang_char + +log "Decoding with H on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_char/H.fst \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "Decoding with HL on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_char/HL.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "Decoding with HLG on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_char/HLG.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +rm -rf $repo diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 400595749..7d55a50e1 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -60,7 +60,7 @@ jobs: - name: Install Python dependencies run: | - grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf pip install --no-binary protobuf protobuf==3.20.* diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py old mode 100644 new mode 100755 index 1df3cfdc2..49871d437 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -23,12 +23,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +63,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the 
tokens.txt.", ) parser.add_argument( @@ -98,16 +97,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - logging.info(params) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + num_classes = num_tokens(token_table) + 1 # +1 for the blank + + logging.info(params) device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py new file mode 120000 index 000000000..896b78aef --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py new file mode 120000 index 000000000..aa1b6073d --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py new file mode 120000 index 000000000..0cf42ce30 --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/test_transformer.py b/egs/aishell/ASR/conformer_ctc/test_transformer.py old mode 100644 new mode 100755 diff --git a/egs/aishell/ASR/local/prepare_lang_fst.py b/egs/aishell/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/aishell/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index ff8e1301d..9de060e73 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -143,6 +143,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./local/prepare_lang.py --lang-dir $lang_phone_dir fi + # Train a bigram P for MMI training if [ ! -f $lang_phone_dir/transcript_words.txt ]; then log "Generate data to train phone based bigram P" @@ -203,6 +204,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py --lang-dir $lang_char_dir fi + + if [ ! -f $lang_char_dir/HLG.fst ]; then + ./local/prepare_lang_fst.py --lang-dir $lang_phone_dir --ngram-G ./data/lm/G_3_gram.fst.txt + fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py index b52c7cfed..8dd856a4e 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -7,6 +7,8 @@ on CPU using OpenFST and decoders from kaldi. 
Usage: +(1) LibriSpeech conformer_ctc + ./conformer_ctc/jit_pretrained_decode_with_H.py \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --H ./data/lang_bpe_500/H.fst \ @@ -14,6 +16,17 @@ Usage: ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_char/H.fst \ + --tokens ./data/lang_char/tokens.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -23,12 +36,11 @@ import logging import math from typing import Dict, List -import kaldi_hmm_gmm import kaldifeat import kaldifst import torch import torchaudio -from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions from torch.nn.utils.rnn import pad_sequence diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index 3420c4da3..796e19661 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -7,6 +7,8 @@ on CPU using OpenFST and decoders from kaldi. Usage: +(1) LibriSpeech conformer_ctc + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --HL ./data/lang_bpe_500/HL.fst \ @@ -14,6 +16,17 @@ Usage: ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_char/HL.fst \ + --words ./data/lang_char/words.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + + Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -23,12 +36,11 @@ import logging import math from typing import Dict, List -import kaldi_hmm_gmm import kaldifeat import kaldifst import torch import torchaudio -from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions from torch.nn.utils.rnn import pad_sequence diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py index 42129f073..0024d5c9c 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -7,6 +7,8 @@ on CPU using OpenFST and decoders from kaldi. 
Usage: +(1) LibriSpeech conformer_ctc + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --HLG ./data/lang_bpe_500/HLG.fst \ @@ -14,6 +16,16 @@ Usage: ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_char/HLG.fst \ + --words ./data/lang_char/words.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -23,12 +35,11 @@ import logging import math from typing import Dict, List -import kaldi_hmm_gmm import kaldifeat import kaldifst import torch import torchaudio -from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions from torch.nn.utils.rnn import pad_sequence diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py index 209ab477a..ff8c742af 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -28,7 +28,7 @@ import kaldifeat import kaldifst import torch import torchaudio -from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions from torch.nn.utils.rnn import pad_sequence diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py index 74864e17d..05ba74f9a 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -28,7 +28,7 @@ import kaldifeat import kaldifst import torch import torchaudio -from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions from torch.nn.utils.rnn import pad_sequence diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md index 07b0ff8cd..0096bc096 100644 --- a/icefall/ctc/README.md +++ b/icefall/ctc/README.md @@ -1,17 +1,17 @@ # Introduction This folder uses [kaldifst][kaldifst] for graph construction -and decoders from [kaldi-hmm-gmm][kaldi-hmm-gmm] for CTC decoding. +and decoders from [kaldi-decoder][kaldi-decoder] for CTC decoding. It supports only `CPU`. You can use ```bash -pip install kaldifst kaldi-hmm-gmm +pip install kaldifst kaldi-decoder ``` to install the dependencies. 
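+
+As a rough sketch of how these pieces fit together (this mirrors the
+`jit_pretrained_decode_with_HLG.py` scripts in the recipes; the graph path,
+decoder options and the `log_probs` matrix below are assumptions, not fixed
+APIs of this folder):
+
+```python
+import kaldifst
+from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
+
+HLG = kaldifst.StdVectorFst.read("data/lang_bpe_500/HLG.fst")
+decoder = FasterDecoder(HLG, FasterDecoderOptions(max_active=3000))
+
+# log_probs: a (num_frames, vocab_size) matrix of CTC log-probabilities
+decoder.decode(DecodableCtc(log_probs))
+if decoder.reached_final():
+    ok, best_path = decoder.get_best_path()
+    ok, _, word_ids, _ = kaldifst.get_linear_symbol_sequence(best_path)
+```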
-[kaldi-hmm-gmm]: https://github.com/csukuangfj/kaldi-hmm-gmm +[kaldi-decoder]: https://github.com/i2-fsa/kaldi-decoder [kaldifst]: https://github.com/k2-fsa/kaldifst [k2]: https://github.com/k2-fsa/k2 diff --git a/requirements.txt b/requirements.txt index c031d683c..5a8326619 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ kaldifst kaldilm kaldialign -kaldi-hmm-gmm +kaldi-decoder sentencepiece>=0.1.96 tensorboard typeguard From 109354b6b8199fa27cd8d4310b59a2e45da1d537 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 2 Oct 2023 14:00:06 +0800 Subject: [PATCH 054/113] Add CTC HLG decoding for zipformer (#1287) --- ...onformer-ctc.sh => run-pre-trained-ctc.sh} | 60 ++- ...nformer-ctc.yml => run-pretrained-ctc.yml} | 10 +- .../jit_pretrained_decode_with_H.py | 4 +- .../jit_pretrained_decode_with_HL.py | 8 +- .../jit_pretrained_decode_with_HLG.py | 8 +- .../ASR/zipformer/export-onnx-ctc.py | 436 ++++++++++++++++++ .../ASR/zipformer/onnx_pretrained_ctc.py | 213 +++++++++ .../ASR/zipformer/onnx_pretrained_ctc_H.py | 277 +++++++++++ .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 275 +++++++++++ .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 275 +++++++++++ 10 files changed, 1545 insertions(+), 21 deletions(-) rename .github/scripts/{run-pre-trained-conformer-ctc.sh => run-pre-trained-ctc.sh} (79%) rename .github/workflows/{run-pretrained-conformer-ctc.yml => run-pretrained-ctc.yml} (91%) create mode 100755 egs/librispeech/ASR/zipformer/export-onnx-ctc.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-ctc.sh similarity index 79% rename from .github/scripts/run-pre-trained-conformer-ctc.sh rename to .github/scripts/run-pre-trained-ctc.sh index ea400c628..7d6449c9a 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-ctc.sh @@ -10,7 +10,57 @@ log() { pushd egs/librispeech/ASR -# repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +log "CTC greedy search" + +./zipformer/onnx_pretrained_ctc.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC H decoding" + +./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + --H $repo/H.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC HL decoding" + +./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HL $repo/HL.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC HLG decoding" + +./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HLG $repo/HLG.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +rm -rf $repo + 
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 log "Downloading pre-trained model from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -128,7 +178,9 @@ repo=$(basename $repo_url) pushd $repo git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "data/lm/G_3_gram_char.fst.txt" +git lfs pull --include "data/lang_char/H.fst" +git lfs pull --include "data/lang_char/HL.fst" +git lfs pull --include "data/lang_char/HLG.fst" popd @@ -153,10 +205,6 @@ popd ls -lh $repo/exp -log "Generating H.fst, HL.fst" - -./local/prepare_lang_fst.py --lang-dir $repo/data/lang_char --ngram-G $repo/data/lm/G_3_gram_char.fst.txt - ls -lh $repo/data/lang_char log "Decoding with H on CPU with OpenFst" diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-ctc.yml similarity index 91% rename from .github/workflows/run-pretrained-conformer-ctc.yml rename to .github/workflows/run-pretrained-ctc.yml index 54845159d..074a63dfc 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-ctc.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: run-pre-trained-conformer-ctc +name: run-pre-trained-ctc on: push: @@ -31,12 +31,12 @@ on: default: 'y' concurrency: - group: run_pre_trained_conformer_ctc-${{ github.ref }} + group: run_pre_trained_ctc-${{ github.ref }} cancel-in-progress: true jobs: - run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' + run_pre_trained_ctc: + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc' runs-on: ${{ matrix.os }} strategy: matrix: @@ -84,4 +84,4 @@ jobs: export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-conformer-ctc.sh + .github/scripts/run-pre-trained-ctc.sh diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py index 8dd856a4e..4bdec9e11 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -145,7 +145,7 @@ def decode( decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -157,7 +157,7 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] # tokens are incremented during graph construction diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index 796e19661..d5a1dba3c 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -132,8 +132,8 @@ def decode( contains output from log_softmax. HL: The HL graph. - word2token: - A map mapping token ID to word string. + id2word: + A map mapping word ID to word string. 
Returns: Return a list of decoded words. """ @@ -145,7 +145,7 @@ def decode( decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -157,7 +157,7 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] # are shifted by 1 during graph construction diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py index 0024d5c9c..216677a23 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -131,8 +131,8 @@ def decode( contains output from log_softmax. HLG: The HLG graph. - word2token: - A map mapping token ID to word string. + id2word: + A map mapping word ID to word string. Returns: Return a list of decoded words. """ @@ -144,7 +144,7 @@ def decode( decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -156,7 +156,7 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] # are shifted by 1 during graph construction diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py new file mode 100755 index 000000000..3345d20d3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a CTC model from PyTorch to ONNX. + +Note that the model is trained using both transducer and CTC loss. This script +exports only the CTC head. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./zipformer/export-onnx-ctc.py \ + --use-transducer 0 \ + --use-ctc 1 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size 16 \ + --left-context-frames 128 + +It will generate the following 2 files inside $repo/exp: + + - model.onnx + - model.int8.onnx + +See ./onnx_pretrained_ctc.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
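+
+    Example (an illustrative sketch only; the filename and the key/value
+    pairs are placeholders):
+
+        add_meta_data(
+            filename="model.onnx",
+            meta_data={"model_type": "zipformer2_ctc", "version": "1"},
+        )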
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for encoder_embed, Zipformer, and ctc_output layer""" + + def __init__( + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + ctc_output: nn.Module, + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_embed: + The first downsampling layer for zipformer. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.ctc_output = ctc_output + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - log_probs, a 3-D tensor of shape (N, T', vocab_size) + - log_probs_len, a 1-D int64 tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + log_probs = self.ctc_output(encoder_out) + + return log_probs, log_probs_len + + +def export_ctc_model_onnx( + model: OnnxModel, + filename: str, + opset_version: int = 11, +) -> None: + """Export the given model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - log_probs, a tensor of shape (N, T', joiner_dim) + - log_probs_len, a tensor of shape (N,) + + Args: + model: + The input model + filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
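+
+    Example (an illustrative sketch only; ``onnx_model`` is assumed to be an
+    instance of the ``OnnxModel`` wrapper defined above):
+
+        export_ctc_model_onnx(onnx_model, "model.onnx", opset_version=13)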
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + model = torch.jit.trace(model, (x, x_lens)) + + torch.onnx.export( + model, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["log_probs", "log_probs_len"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "log_probs": {0: "N", 1: "T"}, + "log_probs_len": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2_ctc", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2 CTC", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + model = OnnxModel( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + ctc_output=model.ctc_output, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"num parameters: {num_param}") + + opset_version = 13 + + logging.info("Exporting ctc model") + filename = params.exp_dir / f"model.onnx" + export_ctc_model_onnx( + model, + filename, + opset_version=opset_version, + ) + logging.info(f"Exported to {filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + filename_int8 = params.exp_dir / f"model.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 100755 index 000000000..eb5cee9cd --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + blank_id = 0 + s = "\n" + for i in range(log_probs.size(0)): + # greedy search + indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1) + token_ids = torch.unique_consecutive(indexes) + + token_ids = token_ids[token_ids != blank_id] + words = token_ids_to_words(token_ids.tolist()) + s += f"{args.sound_files[i]}:\n{words}\n\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 100755 index 000000000..683a7dc20 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + --H /path/to/H.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. 
", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--H", + type=str, + help="""Path to H.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2word: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if i != 1] + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + H=H, + id2token=token_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 100755 index 000000000..0b94bfa65 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. 
Run this file + +./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HL /path/to/HL.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HL", + type=str, + help="""Path to HL.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. 
+ log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HL=HL, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 100755 index 000000000..93569142a --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. 
Run this file + +./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HLG /path/to/HLG.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. 
+ log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HLG=HLG, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() From 82199b8fe1ed77df2ff68e4edc73ee2e09baecc5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 7 Oct 2023 11:44:18 +0800 Subject: [PATCH 055/113] Init commit for swbd (#1146) --- .../run-swbd-conformer-ctc-2023-08-26.sh | 44 + .github/workflows/run-swbd-conformer-ctc.yml | 84 ++ egs/swbd/ASR/.gitignore | 2 + egs/swbd/ASR/README.md | 25 + egs/swbd/ASR/RESULTS.md | 113 +++ egs/swbd/ASR/conformer_ctc/__init__.py | 0 egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 416 +++++++++ egs/swbd/ASR/conformer_ctc/conformer.py | 1 + egs/swbd/ASR/conformer_ctc/decode.py | 853 ++++++++++++++++++ egs/swbd/ASR/conformer_ctc/export.py | 163 ++++ egs/swbd/ASR/conformer_ctc/label_smoothing.py | 1 + egs/swbd/ASR/conformer_ctc/pretrained.py | 1 + egs/swbd/ASR/conformer_ctc/sclite_scoring.py | 148 +++ egs/swbd/ASR/conformer_ctc/subsampling.py | 1 + .../ASR/conformer_ctc/test_label_smoothing.py | 52 ++ .../ASR/conformer_ctc/test_subsampling.py | 48 + 
.../ASR/conformer_ctc/test_transformer.py | 1 + egs/swbd/ASR/conformer_ctc/train.py | 814 +++++++++++++++++ egs/swbd/ASR/conformer_ctc/transformer.py | 1 + egs/swbd/ASR/local/compile_hlg.py | 1 + egs/swbd/ASR/local/compile_lg.py | 1 + egs/swbd/ASR/local/compute_fbank_eval2000.py | 139 +++ egs/swbd/ASR/local/compute_fbank_swbd.py | 163 ++++ .../convert_transcript_words_to_tokens.py | 103 +++ egs/swbd/ASR/local/dict.patch | 380 ++++++++ .../ASR/local/display_manifest_statistics.py | 125 +++ egs/swbd/ASR/local/extend_segments.pl | 99 ++ egs/swbd/ASR/local/filter_cuts.py | 160 ++++ egs/swbd/ASR/local/filter_empty_text.py | 72 ++ egs/swbd/ASR/local/format_acronyms_dict.py | 118 +++ egs/swbd/ASR/local/generate_unique_lexicon.py | 98 ++ .../ASR/local/map_acronyms_transcripts.py | 60 ++ .../normalize_and_filter_supervisions.py | 283 ++++++ egs/swbd/ASR/local/normalize_eval2000.py | 234 +++++ egs/swbd/ASR/local/prepare_lang.py | 1 + egs/swbd/ASR/local/prepare_lang_bpe.py | 274 ++++++ .../ASR/local/prepare_lm_training_data.py | 1 + egs/swbd/ASR/local/rt03_data_prep.sh | 107 +++ egs/swbd/ASR/local/sort_lm_training_data.py | 141 +++ egs/swbd/ASR/local/swbd1_data_prep.sh | 128 +++ egs/swbd/ASR/local/swbd1_map_words.pl | 52 ++ egs/swbd/ASR/local/swbd1_prepare_dict.sh | 101 +++ egs/swbd/ASR/local/train_bpe_model.py | 102 +++ egs/swbd/ASR/local/validate_bpe_lexicon.py | 1 + egs/swbd/ASR/prepare.sh | 463 ++++++++++ egs/swbd/ASR/shared | 1 + egs/swbd/ASR/utils/filter_scp.pl | 87 ++ egs/swbd/ASR/utils/fix_data_dir.sh | 197 ++++ egs/swbd/ASR/utils/parse_options.sh | 97 ++ egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl | 27 + egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl | 38 + 51 files changed, 6622 insertions(+) create mode 100755 .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh create mode 100644 .github/workflows/run-swbd-conformer-ctc.yml create mode 100644 egs/swbd/ASR/.gitignore create mode 100644 egs/swbd/ASR/README.md create mode 100644 egs/swbd/ASR/RESULTS.md create mode 100644 egs/swbd/ASR/conformer_ctc/__init__.py create mode 100644 egs/swbd/ASR/conformer_ctc/asr_datamodule.py create mode 120000 egs/swbd/ASR/conformer_ctc/conformer.py create mode 100755 egs/swbd/ASR/conformer_ctc/decode.py create mode 100755 egs/swbd/ASR/conformer_ctc/export.py create mode 120000 egs/swbd/ASR/conformer_ctc/label_smoothing.py create mode 120000 egs/swbd/ASR/conformer_ctc/pretrained.py create mode 100755 egs/swbd/ASR/conformer_ctc/sclite_scoring.py create mode 120000 egs/swbd/ASR/conformer_ctc/subsampling.py create mode 100755 egs/swbd/ASR/conformer_ctc/test_label_smoothing.py create mode 100755 egs/swbd/ASR/conformer_ctc/test_subsampling.py create mode 120000 egs/swbd/ASR/conformer_ctc/test_transformer.py create mode 100755 egs/swbd/ASR/conformer_ctc/train.py create mode 120000 egs/swbd/ASR/conformer_ctc/transformer.py create mode 120000 egs/swbd/ASR/local/compile_hlg.py create mode 120000 egs/swbd/ASR/local/compile_lg.py create mode 100755 egs/swbd/ASR/local/compute_fbank_eval2000.py create mode 100755 egs/swbd/ASR/local/compute_fbank_swbd.py create mode 100755 egs/swbd/ASR/local/convert_transcript_words_to_tokens.py create mode 100644 egs/swbd/ASR/local/dict.patch create mode 100755 egs/swbd/ASR/local/display_manifest_statistics.py create mode 100755 egs/swbd/ASR/local/extend_segments.pl create mode 100755 egs/swbd/ASR/local/filter_cuts.py create mode 100755 egs/swbd/ASR/local/filter_empty_text.py create mode 100755 egs/swbd/ASR/local/format_acronyms_dict.py create mode 100755 
egs/swbd/ASR/local/generate_unique_lexicon.py create mode 100755 egs/swbd/ASR/local/map_acronyms_transcripts.py create mode 100755 egs/swbd/ASR/local/normalize_and_filter_supervisions.py create mode 100755 egs/swbd/ASR/local/normalize_eval2000.py create mode 120000 egs/swbd/ASR/local/prepare_lang.py create mode 100755 egs/swbd/ASR/local/prepare_lang_bpe.py create mode 120000 egs/swbd/ASR/local/prepare_lm_training_data.py create mode 100755 egs/swbd/ASR/local/rt03_data_prep.sh create mode 100755 egs/swbd/ASR/local/sort_lm_training_data.py create mode 100755 egs/swbd/ASR/local/swbd1_data_prep.sh create mode 100755 egs/swbd/ASR/local/swbd1_map_words.pl create mode 100755 egs/swbd/ASR/local/swbd1_prepare_dict.sh create mode 100755 egs/swbd/ASR/local/train_bpe_model.py create mode 120000 egs/swbd/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/swbd/ASR/prepare.sh create mode 120000 egs/swbd/ASR/shared create mode 100755 egs/swbd/ASR/utils/filter_scp.pl create mode 100755 egs/swbd/ASR/utils/fix_data_dir.sh create mode 100755 egs/swbd/ASR/utils/parse_options.sh create mode 100755 egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl create mode 100755 egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh new file mode 100755 index 000000000..d8cc020e1 --- /dev/null +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/swbd/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-98.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + +for method in ctc-decoding 1best; do + log "$method" + + ./conformer_ctc/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml new file mode 100644 index 000000000..842691d38 --- /dev/null +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -0,0 +1,84 @@ +# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-swbd-conformer_ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +concurrency: + group: run-swbd-conformer_ctc-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-swbd-conformer_ctc: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore new file mode 100644 index 000000000..11d674922 --- /dev/null +++ b/egs/swbd/ASR/.gitignore @@ -0,0 +1,2 @@ +switchboard_word_alignments.tar.gz +./swb_ms98_transcriptions/ diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md new file mode 100644 index 000000000..13b27815a --- /dev/null +++ b/egs/swbd/ASR/README.md @@ -0,0 +1,25 @@ +# Switchboard + +The Switchboard-1 Telephone Speech Corpus (LDC97S62) consists of approximately 260 hours of speech and was originally collected by Texas Instruments in 1990-1, under DARPA sponsorship. The first release of the corpus was published by NIST and distributed by the LDC in 1992-3. Since that release, a number of corrections have been made to the data files as presented on the original CD-ROM set and all copies of the first pressing have been distributed. + +Switchboard is a collection of about 2,400 two-sided telephone conversations among 543 speakers (302 male, 241 female) from all areas of the United States. A computer-driven robot operator system handled the calls, giving the caller appropriate recorded prompts, selecting and dialing another person (the callee) to take part in a conversation, introducing a topic for discussion and recording the speech from the two subjects into separate channels until the conversation was finished. About 70 topics were provided, of which about 50 were used frequently. Selection of topics and callees was constrained so that: (1) no two speakers would converse together more than once and (2) no one spoke more than once on a given topic. + +(The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) 
+ + +## Performance Record +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. + +## Credit + +The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. + +A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored by myself to incorporate with Lhotse and Icefall. + +Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). + +The `sclite_scoring.py` is from the GigaSpeech recipe for post processing and glm-like scoring, which is definitely not an elegant stuff to do. diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md new file mode 100644 index 000000000..f3a22c444 --- /dev/null +++ b/egs/swbd/ASR/RESULTS.md @@ -0,0 +1,113 @@ +## Results +### Switchboard BPE training results (Conformer-CTC) + +#### 2023-09-04 + +The best WER, as of 2023-09-04, for the Switchboard is below + +Results using attention decoder are given as: + +| | eval2000-swbd | eval2000-callhome | eval2000-avg | +|--------------------------------|-----------------|---------------------|--------------| +| `conformer_ctc` | 9.48 | 17.73 | 13.67 | + +Decoding results and models can be found here: +https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 +#### 2023-06-27 + +The best WER, as of 2023-06-27, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 30.80 | 32.29 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.1 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.9 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ + --num-epochs 100 +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 99 \ + --avg 10 \ + --max-duration 50 +``` + +#### 2023-06-26 + +The best WER, as of 2023-06-26, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.3 | 2.5 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.7 | 1.3 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 55 \ + --avg 1 \ + --max-duration 50 +``` + +For your reference, the nbest oracle WERs are: + 
+| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 25.64 | 26.84 | diff --git a/egs/swbd/ASR/conformer_ctc/__init__.py b/egs/swbd/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..59d73c660 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,416 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class SwitchBoardAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. SwitchBoard rt03 + and eval2000). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. 
mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=50, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + buffer_size=50000, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
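+        # A brief note on why this seed matters (based on the _SeedWorkers
+        # helper defined above): each dataloader worker re-seeds itself with
+        # (seed + worker_id), so workers draw different, non-duplicated random
+        # augmentations while remaining reproducible for a fixed main seed.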
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(last=166844) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(first=300) + + @lru_cache() + def test_eval2000_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get eval2000 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" + ) + + @lru_cache() + def test_rt03_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get rt03 cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py new file mode 120000 index 000000000..d1f4209d7 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..2bbade374 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer + +from sclite_scoring import asr_text_post_processing + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. 
+ It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
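+        # nbest_oracle extracts params.num_paths paths from the lattice and,
+        # for each utterance, keeps the path that is closest to the reference
+        # transcript, so the reported WER is an oracle lower bound for any
+        # n-best rescoring method.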
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. 
+ HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + if test_set_name == "test-eval2000": + subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} + elif test_set_name == "test-rt03": + subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} + else: + raise NotImplementedError(f"No implementation for testset {test_set_name}") + for subset, prefix in subsets.items(): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = post_processing(results) + results = ( + sorted(list(filter(lambda x: x[0].startswith(prefix), results))) + if subset != "avg" + else sorted(results) + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+            errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt"
+            with open(errs_filename, "w") as f:
+                wer = write_error_stats(
+                    f,
+                    f"{test_set_name}-{subset}-{key}",
+                    results,
+                    enable_log=enable_log,
+                    sclite_mode=True,
+                )
+                test_set_wers[key] = wer
+
+            if enable_log:
+                logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+        test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+        errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt"
+        with open(errs_info, "w") as f:
+            print("settings\tWER", file=f)
+            for key, val in test_set_wers:
+                print("{}\t{}".format(key, val), file=f)
+
+        s = "\nFor {}-{}, WER of different settings are:\n".format(
+            test_set_name, subset
+        )
+        note = "\tbest for {}".format(test_set_name)
+        for key, val in test_set_wers:
+            s += "{}\t{}{}\n".format(key, val, note)
+            note = ""
+        logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    SwitchBoardAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    params.num_classes = num_classes
+    params.sos_id = sos_id
+    params.eos_id = eos_id
+
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+        "rnn-lm",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + switchboard = SwitchBoardAsrDataModule(args) + + test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( + keep_all_channels=True + ) + # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + # keep_all_channels=True + # ) + + test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) + # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + + # test_sets = ["test-eval2000", "test-rt03"] + # test_dl = [test_eval2000_dl, test_rt03_dl] + test_sets = ["test-eval2000"] + test_dl = [test_eval2000_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py new file mode 100755 index 000000000..1bb6277ad --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded 
+ # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 120000 index 000000000..526bc9678 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py new file mode 100755 index 000000000..0383c4d71 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", + "HMM", + "UMM", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]", +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. 
remove non-scoring words from evaluation + remaining_words = [] + text_split = text.split() + word_to_skip = 0 + for idx, word in enumerate(text_split): + if word_to_skip > 0: + word_to_skip -= 1 + continue + if word in non_scoring_words: + continue + elif word == "CANCELLED": + remaining_words.append("CANCELED") + continue + elif word == "AIRFLOW": + remaining_words.append("AIR") + remaining_words.append("FLOW") + continue + elif word == "PHD": + remaining_words.append("P") + remaining_words.append("H") + remaining_words.append("D") + continue + elif word == "UCLA": + remaining_words.append("U") + remaining_words.append("C") + remaining_words.append("L") + remaining_words.append("A") + continue + elif word == "ONTO": + remaining_words.append("ON") + remaining_words.append("TO") + continue + elif word == "DAY": + try: + if text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + except: + remaining_words.append(word) + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py new file mode 120000 index 000000000..16354dc73 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py new file mode 100755 index 000000000..5d4438fd1 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
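+
+# This test checks that the custom LabelSmoothingLoss matches
+# torch.nn.CrossEntropyLoss with label_smoothing=0.1 and ignore_index=-1
+# for the "none", "sum" and "mean" reductions. It is skipped for
+# torch < 1.10, which lacks the label_smoothing argument.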
+ +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/test_subsampling.py b/egs/swbd/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 000000000..81fa234dd --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 000000000..8b0990ec6 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..7f1eebbcf --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=98, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. 
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. 
+ + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
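+    Returns:
+      Return a tuple of two elements. The first element is the loss tensor:
+      when params.att_rate is not zero it equals
+      (1 - att_rate) * ctc_loss + att_rate * att_loss; otherwise it is the
+      plain CTC loss. The second element is a MetricsTracker holding
+      statistics such as the number of frames and the individual losses.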
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if 
world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + switchboard = SwitchBoardAsrDataModule(args) + + train_cuts = switchboard.train_all_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = switchboard.train_dataloaders(train_cuts) + + valid_cuts = switchboard.dev_cuts() + valid_dl = switchboard.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py new file mode 120000 index 000000000..1c3f43fcf --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/swbd/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/swbd/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py new file mode 100755 index 000000000..d446e8ff3 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. 
If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + manifests = { + "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", + } + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. + partition = "all" + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = CutSet.from_file(manifests[prefix]).resample(16000) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_switchboard( + dir_name="eval2000", + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py new file mode 100755 index 000000000..dd82220c0 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. 
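+
+A typical invocation (illustrative only; adjust the paths and the split
+index for your setup) is:
+
+    ./local/compute_fbank_swbd.py \
+        --split-index 1 \
+        --bpe-model data/lang_bpe_500/bpe.model \
+        --perturb-speed True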
+""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + parser.add_argument( + "--split-index", + type=int, + required=True, + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + split_index: int, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}_split16") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + split_dir = Path("data/manifests/swbd_split16/") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. 
+ partition = "all" + cuts_filename = ( + f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}" + ) + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = ( + CutSet.from_file( + split_dir + / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz" + ) + .resample(16000) + .to_eager() + .filter(lambda c: c.duration > 2.0) + ) + + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, + min_duration=None, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + compute_fbank_switchboard( + dir_name="swbd", + split_index=args.split_index, + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 000000000..a8d5117c9 --- /dev/null +++ b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. 
+ oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/dict.patch b/egs/swbd/ASR/local/dict.patch new file mode 100644 index 000000000..12c63d612 --- /dev/null +++ b/egs/swbd/ASR/local/dict.patch @@ -0,0 +1,380 @@ +1d0 +< file: $SWB/data/dictionary/sw-ms98-dict.text +8645a8646 +> uh-hum ah m hh ah m +9006c9007 +< April ey p r ih l +--- +> April ey p r ax l +9144d9144 +< B ay zh aa n iy z +9261c9261 +< Battle b ae t el +--- +> Battle b ae t ax l +10014a10015 +> Chevy sh eh v iy +10211a10213 +> Colorado k ao l ax r aa d ow +10212a10215 +> Colorado' k ao l ax r aa d ow z +10370c10373 +< Creek k r ih k +--- +> Creek k r iy k +10889a10893 +> Eleven ax l eh v ih n +10951c10955 +< Erie ih r iy +--- +> Erie iy r iy +11183c11187 +< Forever f ax r eh v er +--- +> Forever f er eh v er +11231a11236 +> Friday f r ay d iy +11744a11750 +> History hh ih s t r iy +12004a12011,12012 +> Israel ih z r ih l +> Israel's ih z r ih l z +12573a12582 +> Lincoln l ih ng k ih n +12574a12584 +> Lincolns l ih ng k ih n z +13268c13278 +< NAACP eh ey ey s iy p iy +--- +> NAACP eh n ey ey s iy p iy +13286c13296 +< NIT eh ay t iy +--- +> NIT eh n ay t iy +13292c13302 +< NTSC eh t iy eh s s iy +--- +> NTSC eh n t iy eh s s iy +14058a14069 +> Quarter k ow r t er +14059a14071 +> Quarterback k ow r t er b ae k +14060a14073 +> Quarters k ow r t er z +14569a14583 +> Science s ay n s +15087a15102 +> Sunday s ah n d iy +15088a15104 +> Sunday's s ah n d iy z +15089a15106 +> Sundays s ah n d iy z +15290,15291c15307,15308 +< Texan t eh k sh ih n +< Texan's t eh k sh ih n s +--- +> Texan t eh k s ih n +> Texan's t eh k s ih n s +15335a15353 +> Thousands th aw z ih n z +15739c15757 +< Waco w ae k ow +--- +> Waco w ey k ow +15841a15860 +> Weekends w iy k eh n z +16782a16802 +> acceptable eh k s eh p ax b ax l +16833a16854 +> accounting ax k aw n ih ng +16948a16970 +> address ax d r eh s +17281a17304 +> already aa r d iy +17315a17339 +> am m +17709a17734 +> asked ae s t +17847a17873 +> attorney ih t er n iy +17919a17946 +> autopilot ao t ow p ay l ih t +17960a17988 +> awfully ao f l iy +18221a18250 +> basketball b ae s k ax b ao l +18222a18252 +> basketball's b ae s k ax b ao l z +18302a18333 +> become b ah k ah m +18303a18335 +> becomes b iy k ah m z +18344a18377 +> began b ax g en n +18817c18850 +< bottle b aa t el +--- +> bottle b aa t ax l +19332,19333c19365,19367 +< camera's k ae m ax r ax z +< cameras k ae m ax r ax z +--- +> camera k ae m r ax +> camera's k ae m r ax z +> cameras k ae m r ax z +19411a19446 +> capital k ae p ax l +19505a19541 +> carrying k ae r ih ng +20316a20353,20354 +> combination k aa m ih n ey sh ih n +> combinations k aa m ih n ey sh ih n z +20831a20870 +> contracts k aa n t r ae k s +21010a21050 +> costs k ao s +21062a21103 +> 
county k aw n iy +21371a21413 +> cultural k ao l ch ax r ax l +21372a21415 +> culturally k ao l ch ax r ax l iy +21373a21417 +> culture k ao l ch er +21375a21420 +> cultures k ao l ch er z +21543a21589 +> data d ey t ax +22097a22144 +> differently d ih f ax r ih n t l iy +22972a23020 +> effects ax f eh k t s +23016a23065 +> election ax l eh k sh ih n +23018a23068 +> elections ax l eh k sh ih n z +23052a23103 +> eleven ax l eh v ih n +23242a23294 +> enjoyable ae n jh oy ax b ax l +23248a23301 +> enjoys ae n jh oy z +23293a23347 +> entire ih n t ay r +23295a23350,23351 +> entirely ih n t ay r l iy +> entirety ih n t ay r t iy +23745a23802 +> extra eh k s t er +23818a23876 +> facts f ae k s +24508c24566 +< forever f ax r eh v er +--- +> forever f er eh v er +24514c24572 +< forget f ow r g eh t +--- +> forget f er r g eh t +24521a24580 +> forgot f er r g aa t +24522a24582 +> forgotten f er r g aa t ax n +24563a24624 +> forward f ow er d +24680a24742 +> frightening f r ay t n ih ng +24742a24805 +> full-time f ax l t ay m +24862a24926 +> garage g r aa jh +25218a25283 +> grandmother g r ae m ah dh er +25790a25856 +> heavily hh eh v ax l iy +25949a26016 +> history hh ih s t r iy +26038a26106 +> honestly aa n ax s t l iy +26039a26108 +> honesty aa n ax s t iy +26099a26169 +> horror hh ow r +26155a26226 +> houses hh aw z ih z +26184c26255 +< huh-uh hh ah hh ah +--- +> huh-uh ah hh ah +26189c26260 +< hum-um hh m hh m +--- +> hum-um ah m hh ah m +26236a26308 +> hunting hh ah n ih ng +26307a26380,26381 +> ideal ay d iy l +> idealist ay d iy l ih s t +26369a26444 +> imagine m ae jh ih n +26628a26704 +> individuals ih n d ih v ih jh ax l z +26968a27045 +> interest ih n t r ih s t +27184a27262 +> it'd ih d +27702a27781 +> lead l iy d +28378a28458 +> mandatory m ae n d ih t ow r iy +28885a28966 +> minute m ih n ih t +29167a29249 +> mountains m aw t n z +29317a29400 +> mysteries m ih s t r iy z +29318a29402 +> mystery m ih s t r iy +29470a29555 +> nervous n er v ih s +29578,29580c29663,29665 +< nobody n ow b aa d iy +< nobody'll n ow b aa d iy l +< nobody's n ow b aa d iy z +--- +> nobody n ow b ah d iy +> nobody'll n ow b ah d iy l +> nobody's n ow b ah d iy z +29712a29798 +> nuclear n uw k l iy r +29938a30025 +> onto aa n t ax +30051a30139 +> originally ax r ih jh ax l iy +30507a30596 +> particularly p er t ih k y ax l iy +30755a30845 +> perfectly p er f ih k l iy +30820a30911 +> personally p er s n ax l iy +30915a31007 +> physically f ih z ih k l iy +30986a31079 +> pilot p ay l ih t +30987a31081 +> pilot's p ay l ih t s +31227a31322 +> police p l iy s +31513a31609 +> prefer p er f er +31553a31650 +> prepare p r ax p ey r +31578a31676 +> prescription p er s k r ih p sh ih n +31579a31678 +> prescriptions p er s k r ih p sh ih n z +31770a31870 +> products p r aa d ax k s +31821a31922 +> projects p r aa jh eh k s +31908a32010 +> protect p er t eh k t +31909a32012 +> protected p er t eh k t ih d +31911a32015 +> protection p er t eh k sh ih n +31914a32019 +> protection p er t eh k t ih v +32149a32255 +> quarter k ow r t er +32414a32521 +> read r iy d +32785a32893 +> rehabilitation r iy ax b ih l ih t ey sh ih n +33150a33259 +> resource r ih s ow r s +33151a33261 +> resources r iy s ow r s ih z +33539c33649 +< roots r uh t s +--- +> roots r uw t s +33929a34040 +> science s ay n s +34315a34427 +> seventy s eh v ih n iy +34319,34320c34431,34432 +< severe s ax v iy r +< severely s ax v iy r l iy +--- +> severe s ih v iy r +> severely s ih v iy r l iy +35060a35173 +> software s ao f w ey r +35083a35197 +> solid s 
ao l ih d +35084a35199 +> solidly s ao l ih d l iy +35750a35866 +> stood s t ih d +35854a35971 +> strictly s t r ih k l iy +35889c36006 +< stronger s t r ao ng er +--- +> stronger s t r ao ng g er +36192a36310,36311 +> supposed s p ow z +> supposed s p ow s +36510a36630 +> tastes t ey s +36856a36977 +> thoroughly th er r l iy +36866a36988 +> thousands th aw z ih n z +37081c37203 +< toots t uh t s +--- +> toots t uw t s +37157a37280 +> toward t w ow r d +37158a37282 +> towards t w ow r d z +37564a37689 +> twenties t w eh n iy z +37565a37691 +> twentieth t w eh n iy ih th +37637a37764 +> unacceptable ah n ae k s eh p ax b ax l +37728a37856 +> understand ah n d er s t ae n +37860a37989 +> unless ih n l eh s +38040a38170 +> use y uw z +38049a38180 +> uses y uw z ih z +38125a38257 +> various v ah r iy ih s +38202a38335 +> versus v er s ih z +38381c38514 +< wacko w ae k ow +--- +> wacko w ey k ow +38455c38588 +< wanna w aa n ax +--- +> wanna w ah n ax +38675c38808 +< whatnot w ah t n aa t +--- +> whatnot w aa t n aa t +38676a38810 +> whatsoever w aa t s ow eh v er +38890c39024 +< wok w aa k +--- +> wok w ao k +38910a39045 +> wondering w ah n d r ih ng diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..9aa204863 --- /dev/null +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. 
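+
+For example, the training-cut statistics below (minimum duration 2.0 s) reflect
+the `c.duration > 2.0` filter applied in ./local/compute_fbank_swbd.py, while
+the eval2000 cuts are kept unfiltered (minimum duration 0.1 s).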
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" + path = "./data/fbank/eval2000/eval2000_cuts_all.jsonl.gz" + # path = "./data/fbank/swbd_cuts_all.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Training Cut statistics: +╒═══════════════════════════╤═══════════╕ +│ Cuts count: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Total duration (hh:mm:ss) │ 281:01:26 │ +├───────────────────────────┼───────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼───────────┤ +│ std │ 3.3 │ +├───────────────────────────┼───────────┤ +│ min │ 2.0 │ +├───────────────────────────┼───────────┤ +│ 25% │ 3.2 │ +├───────────────────────────┼───────────┤ +│ 50% │ 5.2 │ +├───────────────────────────┼───────────┤ +│ 75% │ 8.3 │ +├───────────────────────────┼───────────┤ +│ 99% │ 14.4 │ +├───────────────────────────┼───────────┤ +│ 99.5% │ 14.7 │ +├───────────────────────────┼───────────┤ +│ 99.9% │ 15.0 │ +├───────────────────────────┼───────────┤ +│ max │ 57.5 │ +├───────────────────────────┼───────────┤ +│ Recordings available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Features available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Supervisions available: │ 167244 │ +╘═══════════════════════════╧═══════════╛ +Speech duration statistics: +╒══════════════════════════════╤═══════════╤══════════════════════╕ +│ Total speech duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total speaking time duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧═══════════╧══════════════════════╛ + +Eval2000 Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 03:37:13 │ +├───────────────────────────┼──────────┤ +│ mean │ 2.9 │ +├───────────────────────────┼──────────┤ +│ std │ 2.6 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 4.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.6 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 14.7 │ +├───────────────────────────┼──────────┤ +│ max │ 15.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 4473 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/swbd/ASR/local/extend_segments.pl b/egs/swbd/ASR/local/extend_segments.pl new file mode 100755 index 000000000..e8b4894d5 --- /dev/null +++ 
b/egs/swbd/ASR/local/extend_segments.pl @@ -0,0 +1,99 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +if (@ARGV != 1 || !($ARGV[0] =~ m/^-?\d+\.?\d*$/ && $ARGV[0] >= 0)) { + print STDERR "Usage: extend_segments.pl time-in-seconds segments.extended \n" . + "e.g. extend_segments.pl 0.25 segments.2\n" . + "This command modifies a segments file, with lines like\n" . + " \n" . + "by extending the beginning and end of each segment by a certain\n" . + "length of time. This script makes sure the output segments do not\n" . + "overlap as a result of this time-extension, and that there are no\n" . + "negative times in the output.\n"; + exit 1; +} + +$extend = $ARGV[0]; + +@all_lines = (); + +while () { + chop; + @A = split(" ", $_); + if (@A != 4) { + die "invalid line in segments file: $_"; + } + $line = @all_lines; # current number of lines. + ($utt_id, $reco_id, $start_time, $end_time) = @A; + + push @all_lines, [ $utt_id, $reco_id, $start_time, $end_time ]; # anonymous array. + if (! defined $lines_for_reco{$reco_id}) { + $lines_for_reco{$reco_id} = [ ]; # push new anonymous array. + } + push @{$lines_for_reco{$reco_id}}, $line; +} + +foreach $reco_id (keys %lines_for_reco) { + $ref = $lines_for_reco{$reco_id}; + @line_numbers = sort { ${$all_lines[$a]}[2] <=> ${$all_lines[$b]}[2] } @$ref; + + + { + # handle start of earliest segment as a special case. + $l0 = $line_numbers[0]; + $tstart = ${$all_lines[$l0]}[2] - $extend; + if ($tstart < 0.0) { $tstart = 0.0; } + ${$all_lines[$l0]}[2] = $tstart; + } + { + # handle end of latest segment as a special case. + $lN = $line_numbers[$#line_numbers]; + $tend = ${$all_lines[$lN]}[3] + $extend; + ${$all_lines[$lN]}[3] = $tend; + } + for ($i = 0; $i < $#line_numbers; $i++) { + $ln = $line_numbers[$i]; + $ln1 = $line_numbers[$i+1]; + $tend = ${$all_lines[$ln]}[3]; # end of earlier segment. + $tstart = ${$all_lines[$ln1]}[2]; # start of later segment. + if ($tend > $tstart) { + $utt1 = ${$all_lines[$ln]}[0]; + $utt2 = ${$all_lines[$ln1]}[0]; + print STDERR "Warning: for utterances $utt1 and $utt2, segments " . + "already overlap; leaving these times unchanged.\n"; + } else { + $my_extend = $extend; + $max_extend = 0.5 * ($tstart - $tend); + if ($my_extend > $max_extend) { $my_extend = $max_extend; } + $tend += $my_extend; + $tstart -= $my_extend; + ${$all_lines[$ln]}[3] = $tend; + ${$all_lines[$ln1]}[2] = $tstart; + } + } +} + +# leave the numbering of the lines unchanged. +for ($l = 0; $l < @all_lines; $l++) { + $ref = $all_lines[$l]; + ($utt_id, $reco_id, $start_time, $end_time) = @$ref; + printf("%s %s %.2f %.2f\n", $utt_id, $reco_id, $start_time, $end_time); +} + +__END__ + +# testing below. + +# ( echo a1 A 0 1; echo a2 A 3 4; echo b1 B 0 1; echo b2 B 2 3 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.00 +a2 A 2.00 5.00 +b1 B 0.00 1.50 +b2 B 1.50 4.00 +# ( echo a1 A 0 2; echo a2 A 1 3 ) | local/extend_segments.pl 1.0 +Warning: for utterances a1 and a2, segments already overlap; leaving these times unchanged. +a1 A 0.00 2.00 +a2 A 1.00 4.00 +# ( echo a1 A 0 2; echo a2 A 5 6; echo a3 A 3 4 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.50 +a2 A 4.50 7.00 +a3 A 2.50 4.50 diff --git a/egs/swbd/ASR/local/filter_cuts.py b/egs/swbd/ASR/local/filter_cuts.py new file mode 100755 index 000000000..fbcc9e24a --- /dev/null +++ b/egs/swbd/ASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. 
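+    # Worked example of the T >= len(tokens) check above (illustrative numbers):
+    # at roughly 100 feature frames per second, a 10 s cut has num_frames ~= 1000,
+    # so T = ((1000 - 1) // 2 - 1) // 2 = 249 and the cut is excluded only if its
+    # BPE token sequence is longer than 249 tokens.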
+ ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py new file mode 100755 index 000000000..6b3316800 --- /dev/null +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +import logging +from typing import List + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-data-dir", + type=Path, + required=True, + help="Path to the kaldi data dir", + ) + + return parser.parse_args() + + +def load_segments(path: Path): + segments = {} + with open(path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + utt_id, rec_id, start, end = line.split() + segments[utt_id] = line + return segments + + +def filter_text(path: Path): + with open(path, "r") as f: + lines = f.readlines() + return list(filter(lambda x: len(x.strip().split()) > 1, lines)) + + +def write_segments(path: Path, texts: List[str]): + with open(path, "w") as f: + f.writelines(texts) + + +def main(): + args = get_args() + orig_text_dict = filter_text(args.kaldi_data_dir / "text") + write_segments(args.kaldi_data_dir / "text", orig_text_dict) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() + + logging.info("Empty lines filtered") diff --git a/egs/swbd/ASR/local/format_acronyms_dict.py b/egs/swbd/ASR/local/format_acronyms_dict.py new file mode 100755 index 000000000..fa598dd03 --- /dev/null +++ b/egs/swbd/ASR/local/format_acronyms_dict.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd dict to fisher convention +# IBM to i._b._m. +# BBC to b._b._c. 
+# BBCs to b._b._c.s +# BBC's to b._b._c.'s + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input lexicon", required=True) +parser.add_argument("-o", "--output", help="Output lexicon", required=True) +parser.add_argument( + "-L", "--Letter", help="Input single letter pronunciation", required=True +) +parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) +args = parser.parse_args() + + +fin_lex = open(args.input, "r") +fin_Letter = open(args.Letter, "r") +fout_lex = open(args.output, "w") +fout_map = open(args.Map, "w") + +# Initialise single letter dictionary +dict_letter = {} +for single_letter_lex in fin_Letter: + items = single_letter_lex.split() + dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() +fin_Letter.close() +# print dict_letter + +for lex in fin_lex: + items = lex.split() + word = items[0] + lexicon = lex[len(items[0]) + 1 :].strip() + # find acronyms from words with only letters and ' + pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) + if pre_match: + # find if words in the form of xxx's is acronym + if word[-2:] == "'s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-2] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxxs is acronym + elif word[-1] == "s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-1] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxx (not ended with 's or s) is acronym + elif word.find("'") == -1 and word[-1] != "s": + acronym_lexicon = "" + for w in word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + word[-1].lower() + "." 
+ acronym_mapped_back = acronym_mapped_back + word[-1].lower() + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + else: + fout_lex.write(lex) + + else: + fout_lex.write(lex) diff --git a/egs/swbd/ASR/local/generate_unique_lexicon.py b/egs/swbd/ASR/local/generate_unique_lexicon.py new file mode 100755 index 000000000..3459c2f5a --- /dev/null +++ b/egs/swbd/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. 
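+
+    Illustrative example: given
+        [("hello", ["h", "e", "l", "l", "o", "2"]),
+         ("hello", ["h", "e", "l", "l", "o"])]
+    only the first "hello" entry is kept in the returned lexicon.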
+ """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/map_acronyms_transcripts.py b/egs/swbd/ASR/local/map_acronyms_transcripts.py new file mode 100755 index 000000000..ba02aaec3 --- /dev/null +++ b/egs/swbd/ASR/local/map_acronyms_transcripts.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd transcript to fisher convention +# according to first two columns in the input acronyms mapping + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input transcripts", required=True) +parser.add_argument("-o", "--output", help="Output transcripts", required=True) +parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) +args = parser.parse_args() + +fin_map = open(args.Map, "r") +dict_acronym = {} +dict_acronym_noi = {} # Mapping of acronyms without I, i +for pair in fin_map: + items = pair.split("\t") + dict_acronym[items[0]] = items[1] + dict_acronym_noi[items[0]] = items[1] +fin_map.close() +del dict_acronym_noi["I"] +del dict_acronym_noi["i"] + + +fin_trans = open(args.input, "r") +fout_trans = open(args.output, "w") +for line in fin_trans: + items = line.split() + L = len(items) + # First pass mapping to map I as part of acronym + for i in range(L): + if items[i] == "I": + x = 0 + while i - 1 - x >= 0 and re.match(r"^[A-Z]$", items[i - 1 - x]): + x += 1 + + y = 0 + while i + 1 + y < L and re.match(r"^[A-Z]$", items[i + 1 + y]): + y += 1 + + if x + y > 0: + for bias in range(-x, y + 1): + items[i + bias] = dict_acronym[items[i + bias]] + + # Second pass mapping (not mapping 'i' and 'I') + for i in range(len(items)): + if items[i] in dict_acronym_noi.keys(): + items[i] = dict_acronym_noi[items[i]] + sentence = " ".join(items[1:]) + fout_trans.write(items[0] + " " + sentence.lower() + "\n") + +fin_trans.close() +fout_trans.close() diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py new file mode 100755 index 000000000..20ab90caf --- /dev/null +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +# replacement function to convert lowercase letter to uppercase +def to_upper(match_obj): + if match_obj.group() is not None: + return match_obj.group().upper() + + +def insert_groups_and_capitalize_3(match): + return f"{match.group(1)} {match.group(2)} {match.group(3)}".upper() + + +def insert_groups_and_capitalize_2(match): + return f"{match.group(1)} {match.group(2)}".upper() + + +def insert_groups_and_capitalize_1(match): + return f"{match.group(1)}".upper() + + +def insert_groups_and_capitalize_1s(match): + return f"{match.group(1)}".upper() + "'s" + + +class FisherSwbdNormalizer: + """Note: the functions "normalize" and "keep" implement the logic + similar to Kaldi's data prep scripts for Fisher and SWBD: One + notable difference is that we don't change [cough], [lipsmack], + etc. to [noise]. We also don't implement all the edge cases of + normalization from Kaldi (hopefully won't make too much + difference). + """ + + def __init__(self) -> None: + self.remove_regexp_before = re.compile( + r"|".join( + [ + # special symbols + r"\[\[skip.*\]\]", + r"\[skip.*\]", + r"\[pause.*\]", + r"\[silence\]", + r"", + r"", + r"_1", + ] + ) + ) + + # tuples of (pattern, replacement) + # note: Kaldi replaces sighs, coughs, etc with [noise]. + # We don't do that here. + # We also lowercase the text as the first operation. 
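+        # Each entry below is a (compiled pattern, replacement) pair; normalize()
+        # applies them one by one with pattern.sub(), so the order matters and
+        # later rules may rely on rewrites made by earlier ones.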
+ self.replace_regexps: Tuple[re.Pattern, str] = [ + # SWBD: + # [LAUGHTER-STORY] -> STORY + (re.compile(r"\[laughter-(.*?)\]"), r"\1"), + # [WEA[SONABLE]-/REASONABLE] + (re.compile(r"\[\S+/(\S+)\]"), r"\1"), + # -[ADV]AN[TAGE]- -> AN + (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), + # ABSOLUTE[LY]- -> ABSOLUTE- + (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), + # [AN]Y- -> Y- + # -[AN]Y- -> Y- + (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), + # special tokens + (re.compile(r"\[laugh.*?\]"), r"[laughter]"), + (re.compile(r"\[sigh.*?\]"), r"[sigh]"), + (re.compile(r"\[cough.*?\]"), r"[cough]"), + (re.compile(r"\[mn.*?\]"), r"[vocalized-noise]"), + (re.compile(r"\[breath.*?\]"), r"[breath]"), + (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), + (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), + # abbreviations + ( + re.compile( + r"(\w)\.(\w)\.(\w)", + ), + insert_groups_and_capitalize_3, + ), + ( + re.compile( + r"(\w)\.(\w)", + ), + insert_groups_and_capitalize_2, + ), + ( + re.compile( + r"([a-h,j-z])\.", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"\._", + ), + r" ", + ), + ( + re.compile( + r"_(\w)", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"(\w)\.s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"([A-Z])\'s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"(\s\w\b|^\w\b)", + ), + insert_groups_and_capitalize_1, + ), + # words between apostrophes + (re.compile(r"'(\S*?)'"), r"\1"), + # dangling dashes (2 passes) + (re.compile(r"\s-\s"), r" "), + (re.compile(r"\s-\s"), r" "), + # special symbol with trailing dash + (re.compile(r"(\[.*?\])-"), r"\1"), + # Just remove all dashes + (re.compile(r"-"), r" "), + ] + + # unwanted symbols in the transcripts + self.remove_regexp_after = re.compile( + r"|".join( + [ + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ] + ) + ) + + self.post_fixes = [ + # Fix an issue related to [VOCALIZED NOISE] after dash removal + (re.compile(r"\[vocalized noise\]"), "[vocalized-noise]"), + ] + + self.whitespace_regexp = re.compile(r"\s+") + + def normalize(self, text: str) -> str: + text = text.lower() + + # first remove + text = self.remove_regexp_before.sub("", text) + + # then replace + for pattern, sub in self.replace_regexps: + text = pattern.sub(sub, text) + + # then remove + text = self.remove_regexp_after.sub("", text) + + # post fixes + for pattern, sub in self.post_fixes: + text = pattern.sub(sub, text) + + # then clean up whitespace + text = self.whitespace_regexp.sub(" ", text).strip() + + return text.upper() + + +def keep(sup: SupervisionSegment) -> bool: + if "((" in sup.text: + return False + + if " yes", + "[laugh] oh this is [laught] this is great [silence] yes", + "i don't kn- - know A.B.C's", + "so x. 
corp is good?", + "'absolutely yes", + "absolutely' yes", + "'absolutely' yes", + "'absolutely' yes 'aight", + "ABSOLUTE[LY]", + "ABSOLUTE[LY]-", + "[AN]Y", + "[AN]Y-", + "[ADV]AN[TAGE]", + "[ADV]AN[TAGE]-", + "-[ADV]AN[TAGE]", + "-[ADV]AN[TAGE]-", + "[WEA[SONABLE]-/REASONABLE]", + "[VOCALIZED-NOISE]-", + "~BULL", + "Frank E Peretti P E R E T T I", + "yeah yeah like Double O Seven he's supposed to do it", + "P A P E R paper", + "[noise] okay_1 um let me see [laughter] i've been sitting here awhile", + ]: + print(text) + print(normalizer.normalize(text)) + print() + + +if __name__ == "__main__": + test() + # exit() + main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py new file mode 100755 index 000000000..7316193d0 --- /dev/null +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +def remove_punctutation_and_other_symbol(text: str) -> str: + text = text.replace("--", " ") + text = text.replace("//", " ") + text = text.replace(".", " ") + text = text.replace("?", " ") + text = text.replace("~", " ") + text = text.replace(",", " ") + text = text.replace(";", " ") + text = text.replace("(", " ") + text = text.replace(")", " ") + text = text.replace("&", " ") + text = text.replace("%", " ") + text = text.replace("*", " ") + text = text.replace("{", " ") + text = text.replace("}", " ") + return text + + +def eval2000_clean_eform(text: str, eform_count) -> str: + string_to_remove = [] + piece = text.split('">') + for i in range(0, len(piece)): + s = piece[i] + '">' + res = re.search(r"", s) + if res is not None: + res_rm = res.group(1) + string_to_remove.append(res_rm) + for p in string_to_remove: + eform_string = p + text = text.replace(eform_string, " ") + eform_1 = " str: + text = text.replace("[/BABY CRYING]", " ") + text = text.replace("[/CHILD]", " ") + text = text.replace("[[DISTORTED]]", " ") + text = text.replace("[/DISTORTION]", " ") + text = text.replace("[[DRAWN OUT]]", " ") + text = text.replace("[[DRAWN-OUT]]", " ") + text = text.replace("[[FAINT]]", " ") + text = text.replace("[SMACK]", " ") + text = text.replace("[[MUMBLES]]", " ") + text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") + text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") + text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") + text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " ") + text = text.replace("[[PREVIOUS WORD SPOKEN WITH A 
LAUGH]]", " ") + text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") + text = text.replace("[[PROLONGED]]", " ") + text = text.replace("[/RUNNING WATER]", " ") + text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]") + text = text.replace("[[SINGING]]", " ") + text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]") + text = text.replace("[/STATIC]", " ") + text = text.replace("['THIRTIETH' DRAWN OUT]", " ") + text = text.replace("[/VOICES]", " ") + text = text.replace("[[WHISPERED]]", " ") + text = text.replace("[DISTORTION]", " ") + text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ") + text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]") + text = text.replace("[CHILD'S VOICE]", " ") + text = text.replace("[CHILD SCREAMS]", " ") + text = text.replace("[CHILD VOICE]", " ") + text = text.replace("[CHILD YELLING]", " ") + text = text.replace("[CHILD SCREAMING]", " ") + text = text.replace("[CHILD'S VOICE IN BACKGROUND]", " ") + text = text.replace("[CHANNEL NOISE]", " ") + text = text.replace("[CHANNEL ECHO]", " ") + text = text.replace("[ECHO FROM OTHER CHANNEL]", " ") + text = text.replace("[ECHO OF OTHER CHANNEL]", " ") + text = text.replace("[CLICK]", " ") + text = text.replace("[DISTORTED]", " ") + text = text.replace("[BABY CRYING]", " ") + text = text.replace("[METALLIC KNOCKING SOUND]", " ") + text = text.replace("[METALLIC SOUND]", " ") + + text = text.replace("[PHONE JIGGLING]", " ") + text = text.replace("[BACKGROUND SOUND]", " ") + text = text.replace("[BACKGROUND VOICE]", " ") + text = text.replace("[BACKGROUND VOICES]", " ") + text = text.replace("[BACKGROUND NOISE]", " ") + text = text.replace("[CAR HORNS IN BACKGROUND]", " ") + text = text.replace("[CAR HORNS]", " ") + text = text.replace("[CARNATING]", " ") + text = text.replace("[CRYING CHILD]", " ") + text = text.replace("[CHOPPING SOUND]", " ") + text = text.replace("[BANGING]", " ") + text = text.replace("[CLICKING NOISE]", " ") + text = text.replace("[CLATTERING]", " ") + text = text.replace("[ECHO]", " ") + text = text.replace("[KNOCK]", " ") + text = text.replace("[NOISE-GOOD]", "[NOISE]") + text = text.replace("[RIGHT]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[SQUEAK]", " ") + text = text.replace("[STATIC]", " ") + text = text.replace("[[SAYS WITH HIGH-PITCHED SCREAMING LAUGHTER]]", " ") + text = text.replace("[UH]", "UH") + text = text.replace("[MN]", "[VOCALIZED-NOISE]") + text = text.replace("[VOICES]", " ") + text = text.replace("[WATER RUNNING]", " ") + text = text.replace("[SOUND OF TWISTING PHONE CORD]", " ") + text = text.replace("[SOUND OF SOMETHING FALLING]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[NOISE OF MOVING PHONE]", " ") + text = text.replace("[SOUND OF RUNNING WATER]", " ") + text = text.replace("[CHANNEL]", " ") + text = text.replace("[SILENCE]", " ") + text = text.replace("-[W]HERE", "WHERE") + text = text.replace("Y[OU]I-", "YOU I") + text = text.replace("-[A]ND", "AND") + text = text.replace("JU[ST]", "JUST") + text = text.replace("{BREATH}", " ") + text = text.replace("{BREATHY}", " ") + text = text.replace("{CHANNEL NOISE}", " ") + text = text.replace("{CLEAR THROAT}", " ") + + text = text.replace("{CLEARING THROAT}", " ") + text = text.replace("{CLEARS THROAT}", " ") + text = text.replace("{COUGH}", " ") + text = text.replace("{DRAWN OUT}", " ") + text = text.replace("{EXHALATION}", " ") + text = text.replace("{EXHALE}", " ") + text = text.replace("{GASP}", " ") + text = text.replace("{HIGH 
SQUEAL}", " ") + text = text.replace("{INHALE}", " ") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LIPSMACK}", " ") + text = text.replace("{LIPSMACK}", " ") + + text = text.replace("{NOISE OF DISGUST}", " ") + text = text.replace("{SIGH}", " ") + text = text.replace("{SNIFF}", " ") + text = text.replace("{SNORT}", " ") + text = text.replace("{SHARP EXHALATION}", " ") + text = text.replace("{BREATH LAUGH}", " ") + + text = text.replace("[LAUGHTER]", " ") + text = text.replace("[NOISE]", " ") + text = text.replace("[VOCALIZED-NOISE]", " ") + text = text.replace("-", " ") + return text + + +def remove_languagetag(text: str) -> str: + langtag = re.findall(r"<(.*?)>", text) + for t in langtag: + text = text.replace(t, " ") + text = text.replace("<", " ") + text = text.replace(">", " ") + return text + + +def eval2000_normalizer(text: str) -> str: + # print("TEXT original: ",text) + eform_count = text.count("contraction e_form") + # print("eform corunt:", eform_count) + if eform_count > 0: + text = eval2000_clean_eform(text, eform_count) + text = text.upper() + text = remove_languagetag(text) + text = replace_silphone(text) + text = remove_punctutation_and_other_symbol(text) + text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") + text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") + spaces = re.findall(r"\s+", text) + for sp in spaces: + text = text.replace(sp, " ") + text = text.strip() + # text = self.whitespace_regexp.sub(" ", text).strip() + # print(text) + return text + + +def main(): + args = get_args() + sups = load_manifest_lazy_or_eager(args.input_sups) + assert isinstance(sups, SupervisionSet) + + tot, skip = 0, 0 + with SupervisionSet.open_writer(args.output_sups) as writer: + for sup in tqdm(sups, desc="Normalizing supervisions"): + tot += 1 + sup.text = eval2000_normalizer(sup.text) + if not sup.text: + skip += 1 + continue + writer.write(sup) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py new file mode 100755 index 000000000..d82a085ec --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. 
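+    # (Illustrative values only: a word such as "HELLO" might be encoded to ids
+    # like [22, 57] and mapped back to pieces like ["▁HE", "LLO"], depending on
+    # the BPE model.)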
+ words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!SIL", + "", + args.oov, + "#0", + "", + "", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py new file mode 120000 index 000000000..abc00d421 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh new file mode 100755 index 000000000..8a5f64324 --- /dev/null +++ b/egs/swbd/ASR/local/rt03_data_prep.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash + +# RT-03 data preparation (conversational telephone speech part only) +# Adapted from Arnab Ghoshal's script for Hub-5 Eval 2000 by Peng Qi + +# To be run from one directory above this 
script. + +# Expects the standard directory layout for RT-03 + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/corpora/LDC/LDC2007S10" + echo "See comments in the script for more details" + exit 1 +fi + +sdir=$1 +[ ! -d $sdir/data/audio/eval03/english/cts ] && + echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1 +[ ! -d $sdir/data/references/eval03/english/cts ] && + echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1 + +dir=data/local/rt03 +mkdir -p $dir + +rtroot=$sdir +tdir=$sdir/data/references/eval03/english/cts +sdir=$sdir/data/audio/eval03/english/cts + +find -L $sdir -iname '*.sph' | sort >$dir/sph.flist +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +awk -v sph2pipe=$sph2pipe '{ + printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); + printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# Get segments file... +# segments file format is: utt-id side-id start-time end-time, e.g.: +# sw02001-A_000098-001156 sw02001-A 0.98 11.56 +#pem=$sdir/english/hub5e_00.pem +#[ ! -f $pem ] && echo "No such file $pem" && exit 1; +# pem file has lines like: +# en_4156 A unknown_speaker 301.85 302.48 + +#grep -v ';;' $pem \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + print utt,spk,$4,$5;}' | + sort -u >$dir/segments + +# stm file has lines like: +# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER +# TODO(arnab): We should really be lowercasing this since the Edinburgh +# recipe uses lowercase. This is not used in the actual scoring. +#grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | + sort >$dir/text.all + +# We'll use the stm file for sclite scoring. There seem to be various errors +# in the stm file that upset hubscr.pl, and we fix them here. +cat $tdir/*.stm | + sed -e 's:((:(:' -e 's:::g' -e 's:::g' | + grep -v inter_segment_gap | + awk '{ + printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }' \ + >$dir/stm +#$tdir/reference/hub5e00.english.000405.stm > $dir/stm +cp $rtroot/data/trans_rules/en20030506.glm $dir/glm + +# next line uses command substitution +# Just checking that the segments are the same in pem vs. stm. +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && + echo "Segments from pem file and stm file do not match." && exit 1 + +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text + +# create an utt2spk file that assumes each conversation side is +# a separate speaker. 
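+# e.g. the segments entry "sw02001-A_000098-001156 sw02001-A 0.98 11.56"
+# (see the example format above) produces the utt2spk line
+# "sw02001-A_000098-001156 sw02001-A".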
+awk '{print $1,$2;}' $dir/segments >$dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt + +# cp $dir/segments $dir/segments.tmp +# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ +# $dir/segments.tmp > $dir/segments + +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +./utils/fix_data_dir.sh $dir + +echo Data preparation and formatting completed for RT-03 +echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py new file mode 100755 index 000000000..bed3856e4 --- /dev/null +++ b/egs/swbd/ASR/local/sort_lm_training_data.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + 
data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh new file mode 100755 index 000000000..159359491 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_data_prep.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +# Switchboard-1 training data preparation customized for Edinburgh +# Author: Arnab Ghoshal (Jan 2013) + +# To be run from one directory above this script. + +## The input is some directory containing the switchboard-1 release 2 +## corpus (LDC97S62). Note: we don't make many assumptions about how +## you unpacked this. We are just doing a "find" command to locate +## the .sph files. + +## The second input is optional, which should point to a directory containing +## Switchboard transcriptions/documentations (specifically, the conv.tab file). +## If specified, the script will try to use the actual speaker PINs provided +## with the corpus instead of the conversation side ID (Kaldi default). We +## will be using "find" to locate this file so we don't make any assumptions +## on the directory structure. (Peng Qi, Aug 2014) + +#check existing directories +if [ $# != 1 -a $# != 2 ]; then + echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" + exit 1 +fi + +SWBD_DIR=$1 + +dir=data/local/train +mkdir -p $dir + +# Audio data directory check +if [ ! -d $SWBD_DIR ]; then + echo "Error: run.sh requires a directory argument" + exit 1 +fi + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +# Option A: SWBD dictionary file check +[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && + echo "SWBD dictionary file does not exist" && exit 1 + +# find sph audio files +find -L $SWBD_DIR -iname '*.sph' | sort >$dir/sph.flist + +n=$(cat $dir/sph.flist | wc -l) +[ $n -ne 2435 ] && [ $n -ne 2438 ] && + echo Warning: expected 2435 or 2438 data data files, found $n + +# (1a) Transcriptions preparation +# make basic transcription file (add segments info) +# **NOTE: In the default Kaldi recipe, everything is made uppercase, while we +# make everything lowercase here. This is because we will be using SRILM which +# can optionally make everything lowercase (but not uppercase) when mapping +# LM vocabs. 
+awk '{ +name=substr($1,1,6); gsub("^sw","sw0",name); side=substr($1,7,1); +stime=$2; etime=$3; +printf("%s-%s_%06.0f-%06.0f", +name, side, int(100*stime+0.5), int(100*etime+0.5)); +for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" +}' ./swb_ms98_transcriptions/*/*/*-trans.text >$dir/transcripts1.txt + +# test if trans. file is sorted +export LC_ALL=C +sort -c $dir/transcripts1.txt || exit 1 # check it's sorted. + +# Remove SILENCE, and . + +# Note: we have [NOISE], [VOCALIZED-NOISE], [LAUGHTER], [SILENCE]. +# removing [SILENCE], and the and markers that mark +# speech to somone; we will give phones to the other three (NSN, SPN, LAU). +# There will also be a silence phone, SIL. +# **NOTE: modified the pattern matches to make them case insensitive +cat $dir/transcripts1.txt | + perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; + s///gi; + s///gi; + print;' | + awk '{if(NF > 1) { print; } } ' >$dir/transcripts2.txt + +# **NOTE: swbd1_map_words.pl has been modified to make the pattern matches +# case insensitive +local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt >$dir/text + +# format acronyms in text +python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ + -M data/local/dict_nosp/acronyms.map +mv $dir/text_map $dir/text + +# (1c) Make segment files from transcript +#segments file format is: utt-id side-id start-time end-time, e.g.: +#sw02001-A_000098-001156 sw02001-A 0.98 11.56 +awk '{ +segment=$1; +split(segment,S,"[_-]"); +side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; +print segment " " audioname "-" side " " startf/100 " " endf/100 +}' <$dir/text >$dir/segments + +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +awk -v sph2pipe=$sph2pipe '{ +printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); +printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# this file reco2file_and_channel maps recording-id (e.g. sw02001-A) +# to the file name sw02001 and the A, e.g. +# sw02001-A sw02001 A +# In this case it's trivial, but in other corpora the information might +# be less obvious. Later it will be needed for ctm scoring. +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments >$dir/utt2spk || + exit 1 +sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl >$dir/spk2utt || exit 1 + +echo Switchboard-1 data preparation succeeded. + +utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_map_words.pl b/egs/swbd/ASR/local/swbd1_map_words.pl new file mode 100755 index 000000000..4fb8d4ffe --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_map_words.pl @@ -0,0 +1,52 @@ +#!/usr/bin/env perl + +# Modified from swbd_map_words.pl in Kaldi s5 recipe to make pattern +# matches case-insensitive --Arnab (Jan 2013) + +if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. 
+ } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + + +while (<>) { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if ( (!defined $field_begin || $n >= $field_begin) + && (!defined $field_end || $n <= $field_end)) { + # e.g. [LAUGHTER-STORY] -> STORY; + $a =~ s:(|\-)^\[LAUGHTER-(.+)\](|\-)$:$1$2$3:i; + # $1 and $3 relate to preserving trailing "-" + $a =~ s:^\[(.+)/.+\](|\-)$:$1$2:; # e.g. [IT'N/ISN'T] -> IT'N ... note, + # 1st part may include partial-word stuff, which we process further below, + # e.g. [LEM[GUINI]-/LINGUINI] + # the (|\_) at the end is to accept and preserve trailing -'s. + $a =~ s:^(|\-)\[[^][]+\](.+)$:-$2:; # e.g. -[AN]Y , note \047 is quote; + # let the leading - be optional on input, as sometimes omitted. + $a =~ s:^(.+)\[[^][]+\](|\-)$:$1-:; # e.g. AB[SOLUTE]- -> AB-; + # let the trailing - be optional on input, as sometimes omitted. + $a =~ s:([^][]+)\[.+\]$:$1:; # e.g. EX[SPECIALLY]-/ESPECIALLY] -> EX- + # which is a mistake in the input. + $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM + $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- + $a =~ s:_\d$::; # e.g. THEM_1 -> THEM + } + $A[$n] = $a; + } + print join(" ", @A) . "\n"; +} diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh new file mode 100755 index 000000000..eff5fb5f1 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash + +# Formatting the Mississippi State dictionary for use in Edinburgh. Differs +# from the one in Kaldi s5 recipe in that it uses lower-case --Arnab (Jan 2013) + +# To be run from one directory above this script. + +#check existing directories +[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1 + +srcdir=. # This is where we downloaded some stuff.. +dir=./data/local/dict_nosp +mkdir -p $dir +srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text + +# assume swbd_p1_data_prep.sh was done already. +[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1 + +cp $srcdict $dir/lexicon0.txt || exit 1 +chmod a+w $dir/lexicon0.txt +patch 0' | sort >$dir/lexicon1.txt || exit 1 + +cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | + grep -v sil >$dir/nonsilence_phones.txt || exit 1 + +( + echo sil + echo spn + echo nsn + echo lau +) >$dir/silence_phones.txt + +echo sil >$dir/optional_silence.txt + +# No "extra questions" in the input to this setup, as we don't +# have stress or tone. +echo -n >$dir/extra_questions.txt + +cp local/MSU_single_letter.txt $dir/ +# Add to the lexicon the silences, noises etc. +# Add single letter lexicon +# The original swbd lexicon does not have precise single letter lexicion +# e.g. it does not have entry of W +( + echo '!SIL SIL' + echo '[VOCALIZED-NOISE] spn' + echo '[NOISE] nsn' + echo '[LAUGHTER] lau' + echo ' spn' +) | + cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 + +# Map the words in the lexicon. That is-- for each word in the lexicon, we map it +# to a new written form. The transformations we do are: +# remove laughter markings, e.g. +# [LAUGHTER-STORY] -> STORY +# Remove partial-words, e.g. +# -[40]1K W AH N K EY +# becomes -1K +# and +# -[AN]Y IY +# becomes +# -Y +# -[A]B[OUT]- B +# becomes +# -B- +# Also, curly braces, which appear to be used for "nonstandard" +# words or non-words, are removed, e.g. 
+# {WOLMANIZED} W OW L M AX N AY Z D +# -> WOLMANIZED +# Also, mispronounced words, e.g. +# [YEAM/YEAH] Y AE M +# are changed to just e.g. YEAM, i.e. the orthography +# of the mispronounced version. +# Note-- this is only really to be used in training. The main practical +# reason is to avoid having tons of disambiguation symbols, which +# we otherwise would get because there are many partial words with +# the same phone sequences (most problematic: S). +# Also, map +# THEM_1 EH M -> THEM +# so that multiple pronunciations just have alternate entries +# in the lexicon. + +local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ + >$dir/lexicon3.txt || exit 1 + +python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ + -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map +cat $dir/acronyms_raw.map | sort -u >$dir/acronyms.map + +(echo 'i ay') | cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u >$dir/lexicon5.txt + +pushd $dir >&/dev/null +ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. +popd >&/dev/null +rm $dir/lexiconp.txt 2>/dev/null +echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..9b4e28635 --- /dev/null +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. 
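+    # The two user-defined symbols above occupy ids 0 and 1, so setting
+    # unk_id = len(user_defined_symbols) pins the unknown token to id 2.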
+ + user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/swbd/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh new file mode 100755 index 000000000..47d12613b --- /dev/null +++ b/egs/swbd/ASR/prepare.sh @@ -0,0 +1,463 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. Most of them can't be downloaded automatically +# as they are not publically available and require a license purchased +# from the LDC. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=./download +# swbd1_dir="/export/corpora3/LDC/LDC97S62" +swbd1_dir=./download/LDC97S62/ + +# eval2000_dir contains the following files and directories +# downloaded from LDC website: +# - LDC2002S09 +# - hub5e_00 +# - LDC2002T43 +# - reference +eval2000_dir="/export/corpora2/LDC/eval2000" + +rt03_dir="/export/corpora/LDC/LDC2007S10" +fisher_dir="/export/corpora3/LDC/LDC2004T19" + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "swbd1_dir: $swbd1_dir" +log "eval2000_dir: $eval2000_dir" +log "rt03_dir: $rt03_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare SwitchBoard manifest" + # We assume that you have downloaded the SwitchBoard corpus + # to respective dirs + mkdir -p data/manifests + if [ ! 
-e data/manifests/.swbd.done ]; then + lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd + ./local/normalize_and_filter_supervisions.py \ + data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \ + -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all.jsonl.gz + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/swbd/swbd_train_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz + + num_splits=16 + mkdir -p data/manifests/swbd_split${num_splits} + lhotse split ${num_splits} \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \ + data/manifests/swbd_split${num_splits} + + lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 + ./local/normalize_eval2000.py \ + data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ + data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \ + -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz + + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz + + sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm + cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm + + # ./local/rt03_data_prep.sh $rt03_dir + + # normalize eval2000 and rt03 texts by + # 1) convert upper to lower + # 2) remove tags (%AH) (%HESITATION) (%UH) + # 3) remove + # 4) remove "(" or ")" + # for x in rt03; do + # cp data/local/${x}/text data/local/${x}/text.org + # paste -d "" \ + # <(cut -f 1 -d" " data/local/${x}/text.org) \ + # <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | + # sed -e 's/\s\+/ /g' >data/local/${x}/text + # rm data/local/${x}/text.org + # done + + # lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + + touch data/manifests/.swbd.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 I: Compute fbank for SwitchBoard" + if [ ! 
-e data/fbank/.swbd.done ]; then + mkdir -p data/fbank/swbd_split${num_splits}/ + for index in $(seq 1 16); do + ./local/compute_fbank_swbd.py --split-index ${index} & + done + wait + pieces=$(find data/fbank/swbd_split${num_splits} -name "swbd_cuts_all.*.jsonl.gz") + lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz + touch data/fbank/.swbd.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 II: Compute fbank for eval2000" + if [ ! -e data/fbank/.eval2000.done ]; then + mkdir -p data/fbank/eval2000/ + ./local/compute_fbank_eval2000.py + touch data/fbank/.eval2000.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + if ! which jq; then + echo "This script is intended to be used with jq but you have not installed jq + Note: in Linux, you can install jq with the following command: + 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + 2. chmod +x ./jq + 3. cp jq /usr/bin" && exit 1 + fi + if [ ! -f $lang_dir/text ] || [ ! -s $lang_dir/text ]; then + log "Prepare text." + gunzip -c data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $lang_dir/text + fi + + log "Prepare dict" + ./local/swbd1_prepare_dict.sh + cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt + # [noise] nsn + # !sil sil + # spn + cat data/local/dict_nosp/lexicon.txt | sed 's/-//g' | sed 's/\[vocalizednoise\]/\[vocalized-noise\]/g' | + sort | uniq >$lang_dir/lexicon_lower.txt + + cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + + cat data/lang_phone/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! 
-f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + >$lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa >$lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" + lang_dir=data/lang_phone + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/3-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + ./shared/make_kn_lm.py \ + -ngram-order 4 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/4-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! 
-f $out_dir/train.txt ]; then + tail -n 250000 data/lang_phone/input.txt > $out_dir/train.txt + fi + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data data/lang_phone/input.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + head -n 14332 data/lang_phone/input.txt > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + testsets=(eval2000) + + for testset in ${testsets[@]}; do + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/${testset}.txt ]; then + gunzip -c data/manifests/${testset}/eval2000_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $out_dir/${testset}.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/${testset}.txt \ + --lm-archive $out_dir/lm_data-${testset}.pt + done + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + testsets=(eval2000) + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + for testset in ${testsets[@]}; do + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-${testset}.pt \ + --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ + --out-statistics $out_dir/statistics-test-${testset}.txt + done + done +fi diff --git a/egs/swbd/ASR/shared b/egs/swbd/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/swbd/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/swbd/ASR/utils/filter_scp.pl b/egs/swbd/ASR/utils/filter_scp.pl new file mode 100755 index 000000000..b76d37f41 --- /dev/null +++ b/egs/swbd/ASR/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/swbd/ASR/utils/fix_data_dir.sh b/egs/swbd/ASR/utils/fix_data_dir.sh new file mode 100755 index 000000000..ca0972ca8 --- /dev/null +++ b/egs/swbd/ASR/utils/fix_data_dir.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! 
cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" && exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! 
-f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + for x in feats.scp text segments utt2lang $maybe_wav; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/swbd/ASR/utils/parse_options.sh b/egs/swbd/ASR/utils/parse_options.sh new file mode 100755 index 000000000..34476fdb3 --- /dev/null +++ b/egs/swbd/ASR/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 000000000..23992f25d --- /dev/null +++ b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..6e0e438ca --- /dev/null +++ b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. 
+# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +} From ce08230adea2b2c5c45fd3e028cbfd914ffd8ec2 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 7 Oct 2023 11:57:30 +0800 Subject: [PATCH 056/113] Update README.md (#1293) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 523203aa4..c89e7b9aa 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ We provide the following recipes: - [yesno][yesno] - [LibriSpeech][librispeech] - [GigaSpeech][gigaspeech] + - [AMI][ami] - [Aishell][aishell] - [Aishell2][aishell2] - [Aishell4][aishell4] @@ -37,6 +38,7 @@ We provide the following recipes: - [Aidatatang_200zh][aidatatang_200zh] - [WenetSpeech][wenetspeech] - [Alimeeting][alimeeting] + - [Switchboard][swbd] - [TAL_CSASR][tal_csasr] ### yesno @@ -393,4 +395,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [wenetspeech]: egs/wenetspeech/ASR [alimeeting]: egs/alimeeting/ASR [tal_csasr]: egs/tal_csasr/ASR +[ami]: egs/ami +[swbd]: egs/swbd/ASR [k2]: https://github.com/k2-fsa/k2 From fefffc02f68645dbbb2c0a54919c75f37da5dd4f Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 9 Oct 2023 17:39:23 +0800 Subject: [PATCH 057/113] Update optim.py (#1292) --- egs/librispeech/ASR/zipformer/optim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index c9b76526c..8ee2b0eb4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -491,6 +491,12 @@ class ScaledAdam(BatchedOptimizer): if self.show_dominant_parameters: assert p.shape[0] == len(param_names) self._show_gradient_dominating_parameter(tuples, tot_sumsq) + if ans != ans: # e.g. 
ans is nan + ans = 0.0 + if ans == 0.0: + for p, state, param_names in tuples: + p.grad.zero_() # get rid of infinity() + return ans def _show_gradient_dominating_parameter( @@ -573,7 +579,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad if clipping_scale != 1.0: - grad = grad * clipping_scale + grad *= clipping_scale step = state["step"] delta = state["delta"] From 9af144c26b91065a119d4e67c03004974462d24d Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 9 Oct 2023 23:15:22 +0800 Subject: [PATCH 058/113] Zipformer update result (#1296) * update Zipformer results --- README.md | 6 +++--- egs/librispeech/ASR/RESULTS.md | 34 +++++++++++++++++++++------------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c89e7b9aa..da446109d 100644 --- a/README.md +++ b/README.md @@ -120,9 +120,9 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| zipformer | 65.5M | 2.21 | 4.91 | -| zipformer-small | 23.2M | 2.46 | 5.83 | -| zipformer-large | 148.4M | 2.11 | 4.77 | +| zipformer | 65.5M | 2.21 | 4.79 | +| zipformer-small | 23.2M | 2.42 | 5.73 | +| zipformer-large | 148.4M | 2.06 | 4.63 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index b945f43fd..fc7fcdc26 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -75,7 +75,7 @@ See for more details. ##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -90,18 +90,20 @@ You can use to deploy it. | greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 | | modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 | | fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 | +| greedy_search | 2.22 | 4.87 | --epoch 50 --avg 25 | +| modified_beam_search | 2.21 | 4.79 | --epoch 50 --avg 25 | +| fast_beam_search | 2.21 | 4.82 | --epoch 50 --avg 25 | | modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 | | modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 | | modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 | | modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 | - The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1,2,3" ./zipformer/train.py \ --world-size 4 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ @@ -115,8 +117,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ + --epoch 50 \ + --avg 25 \ --use-averaged-model 1 \ --exp-dir ./zipformer/exp \ --max-duration 600 \ @@ -129,7 +131,7 @@ To decode with external language models, please refer to the documentation [here ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -144,13 +146,16 @@ You can use to deploy it. 
| greedy_search | 2.49 | 5.91 | --epoch 40 --avg 13 | | modified_beam_search | 2.46 | 5.83 | --epoch 40 --avg 13 | | fast_beam_search | 2.46 | 5.87 | --epoch 40 --avg 13 | +| greedy_search | 2.46 | 5.86 | --epoch 50 --avg 23 | +| modified_beam_search | 2.42 | 5.73 | --epoch 50 --avg 23 | +| fast_beam_search | 2.46 | 5.78 | --epoch 50 --avg 23 | The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1" ./zipformer/train.py \ --world-size 2 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp-small \ @@ -169,8 +174,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 40 \ - --avg 13 \ + --epoch 50 \ + --avg 23 \ --exp-dir zipformer/exp-small \ --max-duration 600 \ --causal 0 \ @@ -185,7 +190,7 @@ done ##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -200,13 +205,16 @@ You can use to deploy it. | greedy_search | 2.12 | 4.8 | --epoch 40 --avg 13 | | modified_beam_search | 2.11 | 4.7 | --epoch 40 --avg 13 | | fast_beam_search | 2.13 | 4.78 | --epoch 40 --avg 13 | +| greedy_search | 2.08 | 4.69 | --epoch 50 --avg 30 | +| modified_beam_search | 2.06 | 4.63 | --epoch 50 --avg 30 | +| fast_beam_search | 2.09 | 4.68 | --epoch 50 --avg 30 | The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1,2,3" ./zipformer/train.py \ --world-size 4 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp-large \ @@ -224,8 +232,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 40 \ - --avg 16 \ + --epoch 50 \ + --avg 30 \ --exp-dir zipformer/exp-large \ --max-duration 600 \ --causal 0 \ From 0d09a449303eb899eb729238436c3766a0a328b5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 11 Oct 2023 10:06:00 +0800 Subject: [PATCH 059/113] Update train.py (#1299) --- egs/aishell/ASR/pruned_transducer_stateless7/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index 11671db92..9d9dd4288 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -703,7 +703,7 @@ def compute_loss( if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training From 103d617380c5a49599fcc7fd713c69d861989453 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 11 Oct 2023 11:04:20 +0800 Subject: [PATCH 060/113] bug fixes (#1301) --- egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py | 6 +++--- egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py index 1b6991bcd..2f8e658c5 100644 --- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py +++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py @@ -32,7 +32,7 @@ from lhotse.dataset import ( # noqa F401 for 
PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples @@ -230,8 +230,8 @@ class LibriSpeechAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index 59d73c660..aeeb2ef78 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -32,7 +32,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -302,8 +302,8 @@ class SwitchBoardAsrDataModule: buffer_size=50000, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, From cb874e99055c3d62d199ce8d296bb118f3d8aa23 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Oct 2023 12:20:12 +0800 Subject: [PATCH 061/113] add export-onnx.py for stateless8 (#1302) * add export-onnx.py for stateless8 * use tokens.txt to replace bpe.model --- .../export-onnx.py | 604 ++++++++++++++++++ 1 file changed, 604 insertions(+) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py new file mode 100755 index 000000000..3fef231a1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/WeijiZhuang/icefall-asr-librispeech-pruned-transducer-stateless8-2022-12-02 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/WeijiZhuang/icefall-asr-librispeech-pruned-transducer-stateless8-2022-12-02 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless8/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. 
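+
+Besides the three fp32 models listed above, the script also writes int8
+dynamically quantized copies (e.g. encoder-epoch-99-avg-1.int8.onnx) via the
+quantize_dynamic() calls at the end of this file.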
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import num_tokens, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
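+
+    Note: the decoder is wrapped in torch.jit.script() before export (see
+    below); presumably this is done to keep the data-dependent control flow
+    in Decoder.forward (e.g. the need_pad branch) instead of tracing it away.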
+ """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + # is defined in local/train_bpe_model.py + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if 
len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + 
model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() From 16a2748d6cc0eed7f08d034e835a4deef8565d73 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:56:41 +0800 Subject: [PATCH 062/113] PromptASR for contextualized ASR with controllable style (#1250) * Add PromptASR with BERT as text encoder * Support using word-list based content prompts for context biasing * Upload the pretrained models to huggingface * Add usage example --- egs/libriheavy/ASR/RESULTS.md | 205 ++ egs/libriheavy/ASR/prepare_prompt_asr.sh | 36 + egs/libriheavy/ASR/shared | 1 + .../ASR/zipformer_prompt_asr/__init__.py | 0 .../zipformer_prompt_asr/asr_datamodule.py | 520 ++++ .../ASR/zipformer_prompt_asr/beam_search.py | 1 + .../ASR/zipformer_prompt_asr/dataset.py | 586 +++++ .../zipformer_prompt_asr/decode_baseline.py | 791 ++++++ .../ASR/zipformer_prompt_asr/decode_bert.py | 1025 ++++++++ ...decode_bert_with_style_save_decoding_mp.py | 963 +++++++ .../ASR/zipformer_prompt_asr/decoder.py | 130 + .../zipformer_prompt_asr/encoder_interface.py | 43 + .../zipformer_prompt_asr/export_PromptASR.py | 255 ++ .../ASR/zipformer_prompt_asr/joiner.py | 86 + .../ls_text_normalization.py | 153 ++ .../zipformer_prompt_asr/model_baseline.py | 262 ++ .../zipformer_prompt_asr/model_with_BERT.py | 392 +++ .../ASR/zipformer_prompt_asr/optim.py | 1168 +++++++++ .../ASR/zipformer_prompt_asr/pretrained.py | 359 +++ .../ASR/zipformer_prompt_asr/scaling.py | 1872 +++++++++++++ .../ASR/zipformer_prompt_asr/subsampling.py | 276 ++ .../ASR/zipformer_prompt_asr/test_model.py | 119 + .../text_normalization.py | 101 + .../zipformer_prompt_asr/train_baseline.py | 1390 ++++++++++ .../train_bert_encoder.py | 1798 +++++++++++++ .../zipformer_prompt_asr/transcribe_bert.py | 515 ++++ .../ASR/zipformer_prompt_asr/utils.py | 439 ++++ .../ASR/zipformer_prompt_asr/zipformer.py | 2310 +++++++++++++++++ icefall/utils.py | 32 +- 29 files changed, 15825 insertions(+), 3 deletions(-) create mode 100644 egs/libriheavy/ASR/RESULTS.md create mode 100755 egs/libriheavy/ASR/prepare_prompt_asr.sh create mode 120000 egs/libriheavy/ASR/shared create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py create mode 120000 egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py create mode 
100644 egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/optim.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py create mode 100755 egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/utils.py create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md new file mode 100644 index 000000000..4fbedad98 --- /dev/null +++ b/egs/libriheavy/ASR/RESULTS.md @@ -0,0 +1,205 @@ +## Results + +### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) + +#### [zipformer_prompt_asr](./zipformer_prompt_asr) + +See for commit history and +our paper for more details. + + + +##### Training on the medium subset, with content & style prompt, **no** context list + +You can find a pre-trained model, training logs, decoding logs, and decoding results at: + +The training command is: + +```bash +causal=0 +subset=medium +memory_dropout_rate=0.05 +text_encoder_type=BERT + +python ./zipformer_prompt_asr/train_bert_encoder.py \ + --world-size 4 \ + --start-epoch 1 \ + --num-epochs 60 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --use-fp16 True \ + --memory-dropout-rate $memory_dropout_rate \ + --causal $causal \ + --subset $subset \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --text-encoder-type $text_encoder_type \ + --text-encoder-dim 768 \ + --use-context-list 0 \ + --top-k $top_k \ + --use-style-prompt 1 +``` + +The decoding results using utterance-level context (epoch-60-avg-10): + +| decoding method | lh-test-clean | lh-test-other | comment | +|----------------------|---------------|---------------|---------------------| +| modified_beam_search | 3.13 | 6.78 | --use-pre-text False --use-style-prompt False | +| modified_beam_search | 2.86 | 5.93 | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc | +| modified_beam_search | 2.6 | 5.5 | --pre-text-transform mixed-punc --style-text-transform mixed-punc | + + +The decoding command is: + +```bash +for style in mixed-punc upper-no-punc; do + python ./zipformer_prompt_asr/decode_bert.py \ + --epoch 60 \ + --avg 10 \ + --use-averaged-model True \ + --post-normalization True \ + --causal False \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --memory-layer 0 \ + --use-ls-test-set False \ + --use-ls-context-list False \ + --max-prompt-lens 1000 \ + --use-pre-text True \ + --use-style-prompt True \ + --style-text-transform $style \ + --pre-text-transform $style \ + --compute-CER 0 +done +``` + +##### Training on the medium subset, with content & style prompt, **with** 
context list + +You can find a pre-trained model, training logs, decoding logs, and decoding results at: + +This model is trained with an extra type of content prompt (context words), thus it does better +on **word-level** context biasing. Note that to train this model, please first run `prepare_prompt_asr.sh` +to prepare a manifest containing context words. + +The training command is: + +```bash + +causal=0 +subset=medium +memory_dropout_rate=0.05 +text_encoder_type=BERT +use_context_list=True + +# prepare the required data for context biasing +./prepare_prompt_asr.sh --stage 0 --stop_stage 1 + +python ./zipformer_prompt_asr/train_bert_encoder.py \ + --world-size 4 \ + --start-epoch 1 \ + --num-epochs 50 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --use-fp16 True \ + --memory-dropout-rate $memory_dropout_rate \ + --causal $causal \ + --subset $subset \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --text-encoder-type $text_encoder_type \ + --text-encoder-dim 768 \ + --use-context-list $use_context_list \ + --top-k 10000 \ + --use-style-prompt 1 +``` + +*Utterance-level biasing:* + +| decoding method | lh-test-clean | lh-test-other | comment | +|----------------------|---------------|---------------|---------------------| +| modified_beam_search | 3.17 | 6.72 | --use-pre-text 0 --use-style-prompt 0 | +| modified_beam_search | 2.91 | 6.24 | --pre-text-transform upper-no-punc --style-text-transform upper-no-punc | +| modified_beam_search | 2.72 | 5.72 | --pre-text-transform mixed-punc --style-text-transform mixed-punc | + + +The decoding command for the table above is: + +```bash +for style in mixed-punc upper-no-punc; do + python ./zipformer_prompt_asr/decode_bert.py \ + --epoch 50 \ + --avg 10 \ + --use-averaged-model True \ + --post-normalization True \ + --causal False \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/fbank \ + --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --memory-layer 0 \ + --use-ls-test-set False \ + --use-ls-context-list False \ + --max-prompt-lens 1000 \ + --use-pre-text True \ + --use-style-prompt True \ + --style-text-transform $style \ + --pre-text-transform $style \ + --compute-CER 0 +done +``` + +*Word-level biasing:* + +The results are reported on LibriSpeech test-sets using the biasing list provided from . +You need to set `--use-ls-test-set True` so that the LibriSpeech test sets are used. 
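+In the table below, `--use-ls-context-list 1` turns on the biasing list, and
+`--ls-distractors` controls how many distractor words are mixed into it
+(0 adds no distractors).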
+
+| decoding method | ls-test-clean | ls-test-other | comment |
+|----------------------|---------------|---------------|---------------------|
+| modified_beam_search | 2.4 | 5.08 | --use-pre-text 0 --use-style-prompt 0 |
+| modified_beam_search | 2.14 | 4.62 | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 0 |
+| modified_beam_search | 2.14 | 4.64 | --use-ls-context-list 1 --pre-text-transform mixed-punc --style-text-transform mixed-punc --ls-distractors 100 |
+
+The decoding command for the table above is:
+
+```bash
+use_ls_test_set=1
+use_ls_context_list=1
+
+for ls_distractors in 0 100; do
+    python ./zipformer_prompt_asr/decode_bert.py \
+        --epoch 50 \
+        --avg 10 \
+        --use-averaged-model True \
+        --post-normalization True \
+        --causal False \
+        --exp-dir ./zipformer_prompt_asr/exp \
+        --manifest-dir data/fbank \
+        --bpe-model data/lang_bpe_500_fallback_coverage_0.99/bpe.model \
+        --max-duration 1000 \
+        --decoding-method modified_beam_search \
+        --beam-size 4 \
+        --text-encoder-type BERT \
+        --text-encoder-dim 768 \
+        --memory-layer 0 \
+        --use-ls-test-set $use_ls_test_set \
+        --use-ls-context-list $use_ls_context_list \
+        --ls-distractors $ls_distractors \
+        --max-prompt-lens 1000 \
+        --use-pre-text True \
+        --use-style-prompt True \
+        --style-text-transform mixed-punc \
+        --pre-text-transform mixed-punc \
+        --compute-CER 0
+done
+
+```
diff --git a/egs/libriheavy/ASR/prepare_prompt_asr.sh b/egs/libriheavy/ASR/prepare_prompt_asr.sh
new file mode 100755
index 000000000..b931cea26
--- /dev/null
+++ b/egs/libriheavy/ASR/prepare_prompt_asr.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+stage=-1
+stop_stage=100
+manifest_dir=data/fbank
+subset=medium
+topk=10000
+
+. shared/parse_options.sh || exit 1
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download the meta biasing list for LibriSpeech"
+  mkdir -p data/context_biasing
+  cd data/context_biasing
+  git clone https://github.com/facebookresearch/fbai-speech.git
+  cd ../..
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Add rare-words for context biasing to the manifest"
+  python zipformer_prompt_asr/utils.py \
+    --manifest-dir $manifest_dir \
+    --subset $subset \
+    --top-k $topk
+
+fi
diff --git a/egs/libriheavy/ASR/shared b/egs/libriheavy/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/libriheavy/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py b/egs/libriheavy/ASR/zipformer_prompt_asr/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
new file mode 100644
index 000000000..690003377
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
@@ -0,0 +1,520 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import torch +from dataset import PromptASRDataset +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # SingleCutSampler, + CutConcatenate, + CutMix, + DynamicBucketingSampler, + ExtraPadding, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriHeavyAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + if args.use_context_list: + assert args.rare_word_file is not None + with open(args.rare_word_file, "r") as f: + self.rare_word_list = ( + f.read().lower().split() + ) # Use lower-cased for easier style transform + else: + self.rare_word_list = None + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", + ) + + # Libriheavy specific arguments + group.add_argument( + "--subset", + type=str, + default="small", + help="Select the Libriheavy subset (small|medium|large)", + ) + + group.add_argument( + "--use-context-list", + type=str2bool, + default=False, + help="Use the context list of libri heavy", + ) + + group.add_argument( + "--top-k", + type=int, + default=10000, + help="""The top-k words are identified as common words, + the rest as rare words""", + ) + + group.add_argument( + "--with-decoding", + type=str2bool, + default=False, + help="If the texts field contain decoding", + ) + + group.add_argument( + "--random-left-padding", + type=str2bool, + ) + + group.add_argument( + "--rare-word-file", + type=str, + ) + + group.add_argument( + "--long-audio-cuts", + type=str, + default="data/manifest_npr/npr1_cuts_all_guids_0.jsonl.gz", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + text_sampling_func: Callable[[List[str]], str] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. 
+ sampler_state_dict: + The state dict for the training sampler. + """ + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = PromptASRDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = PromptASRDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + raise NotImplementedError( + "SingleCutSampler is no longer supported by lhotse" + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
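+        # Each dataloader worker then re-seeds itself with (seed + worker_id)
+        # in _SeedWorkers, so different workers apply different random transforms.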
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + text_sampling_func: Callable[[List[str]], str] = None, + ) -> DataLoader: + transforms = [] + if self.args.random_left_padding: + logging.info("Enable random left padding") + transforms.append( + ExtraPadding(extra_frames=16, randomized=True, direction="left") + ) + + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = PromptASRDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + else: + validate = PromptASRDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + text_sampling_func=text_sampling_func, + rare_word_list=self.rare_word_list, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get {self.args.subset} cuts") + + if self.args.use_context_list: + path = ( + self.args.manifest_dir + / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz" + ) + elif self.args.with_decoding: + path = ( + self.args.manifest_dir + / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz" + ) + else: + path = ( + self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz" + ) + + logging.info(f"Loading manifest from {path}.") + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def 
librispeech_test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def librispeech_test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def long_audio_cuts(self) -> CutSet: + logging.info("About to get long audio cuts") + cuts = load_manifest_lazy( + self.args.long_audio_cuts, + ) + return cuts + + @lru_cache() + def test_dev_cuts(self) -> CutSet: + logging.info("About to get test dev cuts") + cuts = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz" + ) + return cuts diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py b/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py new file mode 100644 index 000000000..e0bf8f73d --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py @@ -0,0 +1,586 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset import K2SpeechRecognitionDataset +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import compute_num_frames, ifnone +from text_normalization import ( + lower_all_char, + lower_only_alpha, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from torch.utils.data.dataloader import DataLoader, default_collate + + +class PromptASRDataset(torch.utils.data.Dataset): + """This is a dataset for Prompt ASR. It supports the following features: + 1. Select a tuple of (text, pre_text, style_text) randomly from a + list of texts as supervisions. + + """ + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + text_sampling_func: Optional[Callable[[List[str]], str]] = None, + rare_word_list: Optional[List[str]] = None, + ): + """ + Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py + for more details. 
+
+        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
+            objects used to create that batch.
+        :param cut_transforms: A list of transforms to be applied on each sampled batch,
+            before converting cuts to an input representation (audio/features).
+            Examples: cut concatenation, noise cuts mixing, etc.
+        :param input_transforms: A list of transforms to be applied on each sampled batch,
+            after the cuts are converted to audio/features.
+            Examples: normalization, SpecAugment, etc.
+        :param input_strategy: Converts cuts into a collated batch of audio/features.
+            By default, reads pre-computed features from disk.
+        :param text_sampling_func: Sampling a text as transcription from a list of texts.
+        """
+        super().__init__()
+        # Initialize the fields
+        self.return_cuts = return_cuts
+        self.cut_transforms = ifnone(cut_transforms, [])
+        self.input_transforms = ifnone(input_transforms, [])
+        self.input_strategy = input_strategy
+
+        # a text sampling function
+        self.text_sampling_func = text_sampling_func
+        self.rare_word_list = rare_word_list
+
+    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
+        """
+        Return a new batch, with the batch size automatically determined using the constraints
+        of max_frames and max_cuts.
+        """
+        validate_for_asr(cuts)
+
+        # Sort the cuts by duration so that the first one determines the batch time dimensions.
+        cuts = cuts.sort_by_duration(ascending=False)
+
+        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
+        # the supervision boundaries.
+        for tnfm in self.cut_transforms:
+            cuts = tnfm(cuts)
+
+        # Sort the cuts again after transforms
+        cuts = cuts.sort_by_duration(ascending=False)
+
+        # Get a tensor with batched feature matrices, shape (B, T, F)
+        # Collation performs auto-padding, if necessary.
+        input_tpl = self.input_strategy(cuts)
+        if len(input_tpl) == 3:
+            # An input strategy with fault tolerant audio reading mode.
+            # "cuts" may be a subset of the original "cuts" variable,
+            # that only has cuts for which we successfully read the audio.
+            inputs, _, cuts = input_tpl
+        else:
+            inputs, _ = input_tpl
+
+        # Get a dict of tensors that encode the positional information about supervisions
+        # in the batch of feature matrices. The tensors are named "sequence_idx",
+        # "start_frame/sample" and "num_frames/samples".
+        supervision_intervals = self.input_strategy.supervision_intervals(cuts)
+
+        # Apply all available transforms on the inputs, i.e. either audio or features.
+        # This could be feature extraction, global MVN, SpecAugment, etc.
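+        # The stacked tensor has shape (num_supervisions, 3), one row per
+        # supervision: (sequence_idx, start frame/sample, num frames/samples).
+        # Transforms such as SpecAugment use it to locate each supervision.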
+ segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "supervisions": default_collate( + [ + self.text_sampling_func( + texts=supervision.texts, + pre_texts=supervision.pre_texts, + context_list=supervision.context_list + if "context_list" in supervision.custom + else None, + rare_word_list=self.rare_word_list, + ) + if self.text_sampling_func is not None + else { + "text": train_text_normalization(supervision.texts[0]), + "pre_text": train_text_normalization(supervision.pre_texts[0]), + "style_text": train_text_normalization( + supervision.pre_texts[0] + ), + "transform_ids": 0, + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + +def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str: + """A helper function that generates a random substring from a given string + + Args: + s (str): Input string + + Returns: + str: Returned substring + """ + min_len = min(len(s), min_len) + + start = random.randint(0, len(s) - min_len) + end = min(start + max_len, random.randint(start + min_len, len(s))) + + return s[start:end] + + +def triplet_text_sampling( + texts: List[str], + pre_texts: List[str], + context_list: Optional[str] = None, + rare_word_list: Optional[List[str]] = None, + transforms: Optional[List[Callable[[str], str]]] = None, + min_len_style: Optional[int] = 80, +) -> Dict[str, str]: + """This function generates a triplet of + (pre_text, style_text, ref_text). The style of style_text and ref_text + should **always** match, whereas the style of pre_text is arbitrary. + Suppose we have 2 different transforms A,B, and the preceding text is + referred to as pre_text. The following three tuples are all valid: + + (A(pre_text), A(style_text), A(ref_text)) + (A(pre_text), B(style_text), B(ref_text)) + (A(pre_text), A(style_text), A(ref_text)) + (B(pre_text), B(style_text), B(ref_text)) + + If transforms is not given, the following pre-defined transforms + are available: + 0: original (mixed-cased, with punc) + 1: upper_only_alpha (upper-cased, no punc) + + When the transform of text and pre_text match, we can use the whole + pre_text as the prompt text. + + Args: + texts (List[str]): + A list of ref_texts whose first item is the ground truth + text from books. + pre_texts (List[str]): + A list of pre_texts, whose first item is the groundtruth + pre_text from books. 
+ context_list: Optional[str] = None, + A list of biasing words separated by space + rare_word_list: Optional[str] = None, + A list of rare-words separated by space (used as distractors) + transforms (List[Callable[[str], str]]): A list of possible transforms to be applied + + Returns: + A dictionary of ref_text, pre_text, style_text + """ + assert len(texts) == len(pre_texts) + assert len(texts) == 2 + + # we assume the first item to be ground truth + gt_text = texts[0] + gt_pre_text = pre_texts[0] + + if transforms is None: + transforms = [ + lambda x: x, # return it self + upper_only_alpha, + lower_only_alpha, + lower_all_char, + ] + + sampling_weight = [ + 0.7, + 0.3, + 0.0, + 0.0, + ] # Mixed-punc should have the largest sampling prob + + total_transforms = len(transforms) # do not use the recognized trans + + # Randomly sample transforms + i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight) + + # get the normalized text and pre_text + text = transforms[i_text](gt_text) + pre_text = transforms[i_pre_text](gt_pre_text) + + if i_text == i_pre_text: + style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) + else: + # get the pre_text of same style as text + # For now, **don't** do transform to the style text, because we do it after the dataloader + style_text = gt_pre_text + # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text) + style_text = get_substring(style_text, min_len=min_len_style, max_len=150) + + return { + "text": train_text_normalization(text), + "pre_text": train_text_normalization(pre_text), + "style_text": train_text_normalization(style_text), + "transform_ids": i_text, + } + + +def triplet_text_sampling_with_context_list( + texts: List[str], + pre_texts: List[str], + context_list: str, + rare_word_list: List[str], + transforms: Optional[List[Callable[[str], str]]] = None, + min_len_style: Optional[int] = 80, +) -> Dict[str, str]: + """This function generates a triplet of + (pre_text, style_text, ref_text). The pre_text is either the preceding text + or a list of words (context words + distractors). + The style of style_text and ref_text should **always** match, whereas + the style of pre_text is arbitrary. + Suppose we have 2 different transforms A,B, and the preceding text is + referred to as pre_text. The following three tuples are all valid: + + (A(pre_text), A(style_text), A(ref_text)) + (A(pre_text), B(style_text), B(ref_text)) + (A(pre_text), A(style_text), A(ref_text)) + (B(pre_text), B(style_text), B(ref_text)) + + If transforms is not given, the following pre-defined transforms + are available: + 0: original (mixed-cased, with punc) + 1: upper_only_alpha (upper-cased, no punc) + + When the transform of text and pre_text match, we can use the whole + pre_text as the prompt text. + + Args: + texts (List[str]): + A list of ref_texts whose first item is the ground truth + text from books. + pre_texts (List[str]): + A list of pre_texts, whose first item is the groundtruth + pre_text from books. 
+ context_list: Optional[str] = None, + A list of biasing words separated by space + rare_word_list: Optional[str] = None, + A list of rare-words separated by space (used as distractors) + transforms (List[Callable[[str], str]]): A list of possible transforms to be applied + + Returns: + A dictionary of ref_text, pre_text, style_text + Returns: + str: A dictionary + """ + # import pdb; pdb.set_trace() + assert len(texts) == len(pre_texts) + assert len(texts) == 2 + + if context_list is not None: + context_list = context_list.lower() + + # we assume the first item to be ground truth + gt_text = texts[0] + gt_pre_text = pre_texts[0] + + if transforms is None: + transforms = [ + lambda x: x, # return it self + upper_only_alpha, + lower_only_alpha, + lower_all_char, + ] + + sampling_weight = [ + 0.7, + 0.3, + 0.0, + 0.0, + ] # Mixed-punc should have the largest sampling prob + + total_transforms = len(transforms) # do not use the recognized trans + + # Select a transformation randomly + i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight) + + # get the normalized text and pre_text + text = transforms[i_text](gt_text) + pre_text = get_pre_text_with_context_list2( + text=gt_text, + pre_text=gt_pre_text, + context_list=context_list, + rare_words_list=rare_word_list, + ) + pre_text = transforms[i_pre_text](pre_text) + + if i_text == i_pre_text: + style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) + else: + # get the pre_text of same style as text + # For now, **don't** do transform to the style text + style_text = gt_pre_text + # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text) + style_text = get_substring(style_text, min_len=min_len_style, max_len=150) + + return { + "text": train_text_normalization(text), + "pre_text": train_text_normalization(pre_text), + "style_text": train_text_normalization(style_text), + "transform_ids": i_text, + } + + +def get_pre_text_with_context_list( + text: str, + pre_text: str, + context_list: str, + rare_words_list: List[str] = None, +) -> str: + # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha + # By a small proportion of time, use the substring of ref_text as pre_text + + if context_list != "" and context_list is not None: + v = random.random() + if v < 0.5: + # correct + distractors + # sample distractors + num_distractors = random.randint(0, 50) + distractors = random.sample(rare_words_list, num_distractors) + # sample correct + correct = context_list.split() + i = random.randint(1, len(correct)) + correct = random.sample(correct, i) + # combine correct and distractors + pre_text = distractors + correct + random.shuffle(pre_text) + pre_text = " ".join(pre_text) + elif v < 0.7: + splitted = text.split() + sampling_weights = [len(w) ** 1.2 for w in splitted] + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] + i = random.randint(1, min(len(splitted), 20)) + splitted = list(np.random.choice(splitted, i, p=sampling_weights)) + num_distractors = random.randint(0, 70) + distractors = random.sample(rare_words_list, num_distractors) + splitted += distractors + random.shuffle(splitted) # shuffle the list + pre_text = " ".join(splitted) + else: + pre_text = pre_text + else: + v = random.random() + if v < 0.1: + splitted = text.split() + sampling_weights = [len(w) ** 1.2 for w in splitted] + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] + i = random.randint(1, min(len(splitted), 20)) + splitted = 
list(np.random.choice(splitted, i, p=sampling_weights))
+            num_distractors = random.randint(0, 70)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted)  # shuffle the list
+            pre_text = " ".join(splitted)  # join after mixing in the distractors
+        elif v < 0.2:
+            # full distractors
+            num_distractors = random.randint(5, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            pre_text = " ".join(distractors)
+
+        elif v < 0.3:
+            pre_text = get_substring(text, min_len=15, max_len=150)
+        else:
+            pre_text = pre_text
+
+    return pre_text
+
+
+def get_pre_text_with_context_list2(
+    text: str,
+    pre_text: str,
+    context_list: str,
+    rare_words_list: List[str] = None,
+) -> str:
+    # Get the pre_text, either the ground truth preceding text or
+    # a list of words consisting of biasing words and distractors.
+    # For a small proportion of the time, use a substring of ref_text as pre_text
+
+    if context_list != "" and context_list is not None:
+        v = random.random()
+        if v < 0.4:
+            # sample distractors
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            # sample correct
+            correct = context_list.split()
+            i = random.randint(1, len(correct))
+            correct = random.sample(correct, i)
+            # combine correct and distractors
+            pre_text = distractors + correct
+            random.shuffle(pre_text)
+            pre_text = " ".join(pre_text)
+        elif v < 0.55:
+            splitted = text.split()
+            sampling_weights = [
+                len(w) ** 1.2 for w in splitted
+            ]  # longer words with higher weights
+            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted)  # shuffle the list
+            pre_text = " ".join(splitted)
+        else:
+            pre_text = pre_text
+    else:
+        v = random.random()
+        if v < 0.3:
+            splitted = text.split()
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted)  # shuffle the list
+            pre_text = " ".join(splitted)  # join after mixing in the distractors
+        elif v < 0.4:
+            # full distractors
+            num_distractors = random.randint(5, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            pre_text = " ".join(distractors)
+        elif v < 0.6:
+            pre_text = get_substring(text, min_len=15, max_len=150)
+        else:
+            pre_text = pre_text
+
+    return pre_text
+
+
+def naive_triplet_text_sampling(
+    texts: List[str],
+    pre_texts: List[str],
+    context_list: str = None,
+    rare_word_list: List[str] = None,
+    min_len_style: Optional[int] = 120,
+):
+    # The simplest text sampling function, used only for
+    # evaluation; it uses a fixed sentence as the style text
+
+    return {
+        "text": train_text_normalization(texts[0]),
+        "pre_text": train_text_normalization(pre_texts[0]),
+        "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. 
What do you think?", + "transform_ids": 0, + } + + +def random_shuffle_subset( + data: List[str], + p: float = 0.2, + p_mask: float = 0.05, +) -> List[str]: + """ + Randomly shuffle the subset by probability `p`, which means that p% of the samples + in the original batch are shuffled, the others are kept in the original order. + + With a probability of `p_mask`, replace the original string with an empty string. + + """ + + num_to_shuffle = int(len(data) * p) + id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False) + item_to_shuffle = [data[id] for id in id_to_shuffle] + random.shuffle(item_to_shuffle) + + for id, item in zip(id_to_shuffle, item_to_shuffle): + data[id] = item + + # Randomly mask a proportion of the data to empty string + if p_mask > 0: + for i in range(len(data)): + if random.random() < p_mask: + data[i] = "" + + return data + + +if __name__ == "__main__": + texts = [ + "AA, BB, cC, dD!", + "AA BB CC DD", + ] + + pre_texts = [ + "EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?", + "EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG", + ] + for i in range(10): + print(f"Run: {i}") + print(triplet_text_sampling(texts, pre_texts)) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py new file mode 100644 index 000000000..6a3bab3c8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py @@ -0,0 +1,791 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
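+#
+# Note: this is the baseline decoding script for the model trained with
+# train_baseline.py; it does not feed any text prompt to the encoder.
+# The prompt-based decoding script (decode_bert.py) appears later in this patch.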
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import greedy_search, greedy_search_batch, modified_beam_search +from ls_text_normalization import word_normalization +from text_normalization import ( + ref_text_normalization, + remove_non_alphabetic, + upper_only_alpha, +) +from train_baseline import add_model_arguments, get_params, get_transducer_model +from utils import write_error_stats + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. 
+ """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", + ) + + parser.add_argument( + "--long-audio-recog", + type=str2bool, + default=False, + ) + + parser.add_argument( + "--use-ls-test-set", + type=str2bool, + default=False, + help="Use librispeech test set for evaluation.", + ) + + parser.add_argument( + "--compute-CER", + type=str2bool, + default=True, + help="Reports CER. By default, only reports WER", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. 
See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + texts = batch["supervisions"]["text"] + batch_size = feature.size(0) + + # Get the transducer encoder output + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=feature, + feature_lens=feature_lens, + ) + + hyps = [] + + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
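+
+    # `len(dl)` may be undefined for samplers over lazily-opened cuts, hence the
+    # "?" fallback above.  Greedy search is much faster per batch than beam search,
+    # so its progress is logged less often below.  Each entry appended to `results`
+    # is a (cut_id, ref_words, hyp_words) tuple, e.g.
+    #   ("cut-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"])
+    # (the cut id here is purely illustrative).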
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if not params.use_ls_test_set: + book_names = [ + cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] + ] + else: + book_names = ["" for _ in cut_ids] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, book_name, hyp_words, ref_text in zip( + cut_ids, book_names, hyps, texts + ): + ref_text = ref_text_normalization(ref_text) + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + # if not params.use_ls_test_set: + # results[name + " " + book_name].extend(this_batch) + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + biasing_words: List[str] = None, +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + biasing_words=biasing_words, + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.compute_CER: + # Write CER statistics + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results, char_level=True) + errs_filename = ( + params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=params.compute_CER, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed CER stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + if params.compute_CER: + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tcER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{} CER\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "modified_beam_search", + ) + + if params.long_audio_recog: + params.res_dir = params.exp_dir / (params.decoding_method + "long_audio") + else: + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + 
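+    # The blank/unknown ids and the vocabulary size read from the BPE model above
+    # determine the output dimensions of the transducer decoder and joiner that are
+    # created below.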
logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + LM = None + + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
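+    # (setting `--return-cuts` attaches the cut objects to each batch under
+    # batch["supervisions"]["cut"], which is where decode_dataset reads the ids)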
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() + ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() + long_audio_cuts = libriheavy.long_audio_cuts() + + test_clean_dl = libriheavy.valid_dataloaders( + test_clean_cuts, + ) + test_other_dl = libriheavy.valid_dataloaders( + test_other_cuts, + ) + ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) + ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) + long_audio_dl = libriheavy.valid_dataloaders( + long_audio_cuts, + ) + + if params.use_ls_test_set: + test_sets = ["ls-test-clean", "ls-test-other"] + test_dl = [ls_test_clean_dl, ls_test_other_dl] + else: + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + if params.long_audio_recog: + test_sets = ["long-audio"] + test_dl = [long_audio_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + if params.use_ls_test_set: + f = open( + "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r" + ) + biasing_words = f.read().strip().split() + f.close() + else: + biasing_words = None + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if params.post_normalization: + if "-post-normalization" not in params.suffix: + params.suffix += "-post-normalization" + + new_res = {} + for k in results_dict: + new_ans = [] + for item in results_dict[k]: + id, ref, hyp = item + if params.use_ls_test_set: + hyp = ( + " ".join(hyp).replace("-", " ").split() + ) # handle the hypens + hyp = upper_only_alpha(" ".join(hyp)).split() + hyp = [word_normalization(w.upper()) for w in hyp] + hyp = " ".join(hyp).split() + hyp = [w for w in hyp if w != ""] + ref = upper_only_alpha(" ".join(ref)).split() + else: + hyp = upper_only_alpha(" ".join(hyp)).split() + ref = upper_only_alpha(" ".join(ref)).split() + new_ans.append((id, ref, hyp)) + new_res[k] = new_ans + + save_results( + params=params, + test_set_name=test_set, + results_dict=new_res, + biasing_words=biasing_words, + ) + + if params.suffix.endswith("-post-normalization"): + params.suffix = params.suffix.replace("-post-normalization", "") + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py new file mode 100755 index 000000000..e71999b0a --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py @@ -0,0 +1,1025 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method greedy_search \ + --text-encoder-type BERT \ + --memory-layer 0 \ + --use-pre-text True \ + --use-style-prompt True \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 + + +(2) modified beam search +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --text-encoder-type BERT \ + --memory-layer 0 \ + --use-pre-text True \ + --use-style-prompt True \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 + +(3) Decode LibriSpeech + +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --use-ls-test-set True \ + --beam-size 4 \ + --text-encoder-type BERT \ + --memory-layer 0 \ + --use-pre-text True \ + --use-style-prompt True \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc \ + --compute-CER 0 + +(4) Decode LibriSpeech + biasing list + +biasing_list=100 # could also be 0 + +./zipformer_prompt_asr/decode_bert.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --use-ls-test-set True \ + --use-ls-context-list True \ + --biasing-level utterance \ + --ls-distractors $biasing_list \ + --post-normalization True \ + --text-encoder-type BERT \ + --max-prompt-lens 1000 \ + --style-text-transform mixed-punc \ + --pre-text-transform mixed-punc + + +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import greedy_search, greedy_search_batch, modified_beam_search +from dataset import naive_triplet_text_sampling, random_shuffle_subset +from ls_text_normalization import word_normalization +from text_normalization import ( + _apply_style_transform, + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) +from transformers import BertModel, BertTokenizer +from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the 
checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=True, + help="Use pre-text is available during decoding", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt when evaluation", + ) + + parser.add_argument( + "--max-prompt-lens", + type=int, + default=1000, + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", + ) + + parser.add_argument( + "--compute-CER", + type=str2bool, + default=False, + help="Reports CER. By default, only reports WER", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--use-ls-test-set", + type=str2bool, + default=False, + help="Use librispeech test set for evaluation.", + ) + + parser.add_argument( + "--use-ls-context-list", + type=str2bool, + default=False, + help="If use a fixed context list for LibriSpeech decoding", + ) + + parser.add_argument( + "--biasing-level", + type=str, + default="utterance", + choices=["utterance", "Book", "Chapter"], + ) + + parser.add_argument( + "--ls-distractors", + type=int, + default=0, + help="The number of distractors into context list for LibriSpeech decoding", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer: spm.SentencePieceProcessor, + batch: dict, + biasing_dict: dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + tokenizer: + Tokenizer for the text encoder + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + biasing_dict: + A dictionary in the form `{cut_id: :w1 w2"}` that contains a list + of biasing words (separated with space) + word_table: + The word symbol table. + decoding_graph: + The decoding graph. 
Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + cuts = batch["supervisions"]["cut"] + cut_ids = [c.supervisions[0].id for c in cuts] + batch_size = feature.size(0) + + if "pre_text" in batch["supervisions"] and params.use_pre_text: + pre_texts = batch["supervisions"]["pre_text"] + pre_texts = [train_text_normalization(t) for t in pre_texts] + else: + pre_texts = ["" for _ in range(batch_size)] + + # get the librispeech biasing data + if params.use_pre_text and (params.use_ls_context_list and params.use_ls_test_set): + if params.biasing_level == "utterance": + pre_texts = [biasing_dict[id] for id in cut_ids] + elif params.biasing_level == "Chapter": + chapter_ids = [c.split("-")[1] for c in cut_ids] + pre_texts = [biasing_dict[id] for id in chapter_ids] + elif params.biasing_level == "Book": + chapter_ids = [c.split("-")[1] for c in cut_ids] + pre_texts = [biasing_dict[id] for id in chapter_ids] + else: + raise ValueError(f"Unseen biasing level: {params.biasing_level}") + if params.pre_text_transform == "mixed-punc": + pre_texts = [t.lower() for t in pre_texts] + + # get style_text + if params.use_style_prompt: + fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." + style_texts = batch["supervisions"].get( + "style_text", [fixed_sentence for _ in range(batch_size)] + ) + style_texts = [train_text_normalization(t) for t in style_texts] + else: + style_texts = ["" for _ in range(batch_size)] # use empty string + + # Get the text embedding + if params.use_pre_text or params.use_style_prompt: + # apply style transform to the pre_text and style_text + pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + if not params.use_ls_context_list: + pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts] + + if params.use_style_prompt: + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Use tokenizer to prepare input for text encoder + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + no_limit=True, + ) + logging.info( + f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}" + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + # Get the transducer encoder output + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=feature, + feature_lens=feature_lens, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + hyps = [] + + if params.decoding_method == 
"greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer: spm.SentencePieceProcessor, + biasing_dict: Dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + tokenizer: + Tokenizer for the text encoder + biasing_dict: + A dictionary in the form `{cut_id: :w1 w2"}` that contains a list + of biasing words (separated with space) + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"][ + "text" + ] # By default, this should be in mixed-punc format + + # the style of ref_text should match style_text + texts = _apply_style_transform(texts, params.style_text_transform) + if params.use_style_prompt: + texts = _apply_style_transform(texts, params.style_text_transform) + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + if not params.use_ls_test_set: + try: + book_names = [ + cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"] + ] + except AttributeError: + book_names = [ + cut.id.split("/")[0] for cut in batch["supervisions"]["cut"] + ] + else: + book_names = ["" for _ in cut_ids] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, book_name, hyp_words, ref_text in zip( + cut_ids, book_names, hyps, texts + ): + ref_text = ref_text_normalization( + ref_text + ) # remove full-width symbols & some book marks + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + biasing_words: List[str] = None, +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.compute_CER: + # Write CER statistics + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results, char_level=True) + errs_filename = ( + params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=params.compute_CER, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed CER stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + if params.compute_CER: + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{} CER\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "modified_beam_search", + ) + + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
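+        # Training may use several comma-separated chunk sizes / left-context values
+        # (e.g. "16,32,64,-1"); decoding requires exactly one of each, and the chosen
+        # values are recorded in the result-file suffix below.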
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_pre_text: + params.suffix += ( + f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}" + ) + + if params.use_style_prompt: + params.suffix += f"-style-prompt-{params.style_text_transform}" + + if params.use_ls_context_list: + assert ( + params.use_pre_text + ), "Must set --use-pre-text to True if using context list" + params.suffix += f"-use-{params.biasing_level}-level-ls-context-list" + if params.biasing_level == "utterance" and params.ls_distractors: + params.suffix += f"-ls-context-distractors-{params.ls_distractors}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + LM = None + + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts() + ls_test_other_cuts = libriheavy.librispeech_test_other_cuts() + + test_clean_dl = libriheavy.valid_dataloaders( + test_clean_cuts, text_sampling_func=naive_triplet_text_sampling + ) + test_other_dl = libriheavy.valid_dataloaders( + test_other_cuts, text_sampling_func=naive_triplet_text_sampling + ) + ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts) + ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts) + + if params.use_ls_test_set: + test_sets = ["ls-test-clean", "ls-test-other"] + test_dl = [ls_test_clean_dl, ls_test_other_dl] + else: + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + biasing_dict = None + if params.use_ls_context_list: + if test_set == "ls-test-clean": + biasing_dict = get_facebook_biasing_list( + test_set="test-clean", + num_distractors=params.ls_distractors, + ) + elif test_set == "ls-test-other": + biasing_dict = get_facebook_biasing_list( + test_set="test-other", + num_distractors=params.ls_distractors, + ) + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if params.post_normalization: + if "-post-normalization" not in params.suffix: + params.suffix += "-post-normalization" + + new_res = {} + for k in results_dict: + new_ans = [] + for item in results_dict[k]: + id, ref, hyp = item + if params.use_ls_test_set: + hyp = ( + " ".join(hyp).replace("-", " ").split() + ) # handle the hypens + hyp = upper_only_alpha(" ".join(hyp)).split() + hyp = [word_normalization(w.upper()) for w in hyp] + hyp = " ".join(hyp).split() + hyp = [w for w in hyp if w != ""] + ref = upper_only_alpha(" ".join(ref)).split() + else: + hyp = upper_only_alpha(" ".join(hyp)).split() + ref = upper_only_alpha(" ".join(ref)).split() + new_ans.append((id, ref, hyp)) + new_res[k] = new_ans + + save_results( + params=params, + test_set_name=test_set, + results_dict=new_res, + ) + + if params.suffix.endswith("-post-normalization"): + params.suffix = params.suffix.replace("-post-normalization", "") + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py new file mode 100755 index 000000000..4559ebb6d --- /dev/null +++ 
b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style_save_decoding_mp.py @@ -0,0 +1,963 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import ( + greedy_search, + greedy_search_batch, + greedy_search_batch_with_context, + greedy_search_with_context, + modified_beam_search, +) +from dataset import naive_triplet_text_sampling, random_shuffle_subset +from lhotse import load_manifest_lazy +from text_normalization import ( + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from train_bert_encoder_with_style import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) +from transformers import BertModel, BertTokenizer +from utils import get_facebook_biasing_list + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--log-dir", + type=str, + required=True, + help="Where to store the logs", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--input-manifest", + type=str, + required=True, + help="The input manifest to be decoded", + ) + + parser.add_argument( + "--output-manifest", + type=str, + required=True, + help="Where to store the output manifest (directory)", + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=True, + help="Use pre-text is available during decoding", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt when evaluation", + ) + + parser.add_argument( + "--use-context-embedding", + type=str2bool, + default=False, + help="Use context fuser when evaluation", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ", + ) + + parser.add_argument( + "--compute-CER", + type=str2bool, + default=True, + help="Reports CER. By default, only reports WER", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--use-ls-test-set", + type=str2bool, + default=False, + help="Use librispeech test set for evaluation.", + ) + + parser.add_argument( + "--use-ls-context-list", + type=str2bool, + default=False, + help="If use a fixed context list for LibriSpeech decoding", + ) + + add_model_arguments(parser) + + return parser + + +def _apply_style_transform(text: List[str], transform: str) -> List[str]: + """Apply transform to a list of text. By default, the text are in + ground truth format, i.e mixed-punc. + + Args: + text (List[str]): Input text string + transform (str): Transform to be applied + + Returns: + List[str]: _description_ + """ + if transform == "mixed-punc": + return text + elif transform == "upper-no-punc": + return [upper_only_alpha(s) for s in text] + elif transform == "lower-no-punc": + return [lower_only_alpha(s) for s in text] + elif transform == "lower-punc": + return [lower_all_char(s) for s in text] + else: + raise NotImplementedError(f"Unseen transform: {transform}") + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer, + batch: dict, + biasing_dict: dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. 
See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + cuts = batch["supervisions"]["cut"] + cut_ids = [c.supervisions[0].id for c in cuts] + batch_size = feature.size(0) + + # get pre_text + if "pre_text" in batch["supervisions"] and params.use_pre_text: + pre_texts = batch["supervisions"][ + "text" + ] # use the ground truth ref text as pre_text + pre_texts = [train_text_normalization(t) for t in pre_texts] + else: + pre_texts = ["" for _ in range(batch_size)] + + if params.use_ls_context_list: + pre_texts = [biasing_dict[id] for id in cut_ids] + + # get style_text + if params.use_style_prompt: + fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related." + style_texts = batch["supervisions"].get( + "style_text", [fixed_sentence for _ in range(batch_size)] + ) + style_texts = [train_text_normalization(t) for t in style_texts] + else: + style_texts = ["" for _ in range(batch_size)] # use empty string + + # Get the text embedding input + if params.use_pre_text or params.use_style_prompt: + + # apply style transform to the pre_text and style_text + pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + # pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0) + if params.use_style_prompt: + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Use tokenizer to prepare input for text encoder + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + # Get the transducer encoder output + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=feature, + feature_lens=feature_lens, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + hyps = [] + + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + if memory is None or not params.use_context_embedding: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + else: + memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) + context = model.context_fuser( + memory, padding_mask=memory_key_padding_mask + ) # (N,C) + context = model.joiner.context_proj(context) # (N,C) + hyp_tokens = greedy_search_batch_with_context( + model=model, + encoder_out=encoder_out, + 
encoder_out_lens=encoder_out_lens, + context=context, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + if memory is None or not params.use_context_embedding: + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + cur_context = context[i : i + 1, :] + hyp = greedy_search_with_context( + model=model, + encoder_out=encoder_out_i, + context=cur_context, + max_sym_per_frame=params.max_sym_per_frame, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + tokenizer, + biasing_dict: Dict = None, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
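+    # len(dl) may not be defined for every dataloader; "?" is only used for
+    # the progress logging below.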
+ + if params.decoding_method == "greedy_search": + log_interval = 40 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"][ + "text" + ] # By default, this should be in mixed-punc format + + # the style of ref_text should match style_text + texts = _apply_style_transform(texts, params.style_text_transform) + if params.use_style_prompt: + texts = _apply_style_transform(texts, params.style_text_transform) + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = ref_text_normalization( + ref_text + ) # remove full-width symbols & some book marks + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.compute_CER: + # Write CER statistics + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results, char_level=True) + errs_filename = ( + params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + cer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=params.compute_CER, + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed CER stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + if params.compute_CER: + test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_cers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_cers: + s += "{} CER\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +def add_decoding_result_to_manifest( + in_manifest, + out_manifest: str, + results_dict: Dict, +): + # write the decoding results with prompt to the manifest as an + # extra ref text + new_ans = {} + for key, value in results_dict.items(): + for items in value: + id, ref, hyp = items + new_ans[id] = " ".join(hyp) + + def _add_decoding(c): + key = c.supervisions[0].id + c.supervisions[0].texts.append(new_ans[key]) + return c + + in_manifest = in_manifest.map(_add_decoding) + logging.info(f"Saving manifest to {out_manifest}") + in_manifest.to_file(out_manifest) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + cuts = load_manifest_lazy(args.input_manifest) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + splitted_cuts = cuts.split(num_splits=world_size) + mp.spawn( + run, args=(world_size, args, splitted_cuts), nprocs=world_size, join=True + ) + else: + run(rank=0, world_size=1, args=args, cuts=cuts) + + +@torch.no_grad() +def run(rank, world_size, args, cuts): + params = get_params() + params.update(vars(args)) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_pre_text: + params.suffix += f"-pre-text-{params.pre_text_transform}" + + if params.use_style_prompt: + params.suffix += f"-style-prompt-{params.style_text_transform}" + + params.suffix += f"-{rank}" + + 
world_size = params.world_size
+
+    params.output_manifest = Path(params.output_manifest)
+    if world_size > 1:
+        cuts = cuts[rank]
+        out_name = params.output_manifest / f"with_decoding_job_{rank}.jsonl.gz"
+    else:
+        out_name = params.output_manifest / "with_decoding.jsonl.gz"
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    setup_logger(f"{params.log_dir}/log-get-manifest-with-decoding-{rank}")
+    logging.info("Decoding started")
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+    tokenizer = get_tokenizer(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    LM = None
+
+    decoding_graph = None
+    word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + dl = libriheavy.valid_dataloaders( + cuts, text_sampling_func=naive_triplet_text_sampling + ) + + test_sets = ["test"] + test_dl = [dl] + + for test_set, test_dl in zip(test_sets, test_dl): + biasing_dict = None + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + biasing_dict=biasing_dict, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + # save_results( + # params=params, + # test_set_name=test_set, + # results_dict=results_dict, + # ) + + add_decoding_result_to_manifest( + in_manifest=cuts, + out_manifest=out_name, + results_dict=results_dict, + ) + + logging.info("Done!") + + +# torch.set_num_threads(1) +# torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py new file mode 100644 index 000000000..93e0f9f7e --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py @@ -0,0 +1,130 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scaling import Balancer + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + padding_idx=blank_id, + ) + # the balancers are to avoid any drift in the magnitude of the + # embeddings, which would interact badly with parameter averaging. 
+ self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim // 4, # group size == 4 + bias=False, + ) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + + embedding_out = self.balancer(embedding_out) + + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + embedding_out = self.balancer2(embedding_out) + + return embedding_out diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py b/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py new file mode 100644 index 000000000..257facce4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/encoder_interface.py @@ -0,0 +1,43 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + + +class EncoderInterface(nn.Module): + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (batch_size, input_seq_len, num_features) + containing the input features. + x_lens: + A tensor of shape (batch_size,) containing the number of frames + in `x` before padding. + Returns: + Return a tuple containing two tensors: + - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) + containing unnormalized probabilities, i.e., the output of a + linear layer. + - encoder_out_lens, a tensor of shape (batch_size,) containing + the number of frames in `encoder_out` before padding. 
+ """ + raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py new file mode 100644 index 000000000..e0bc556a8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +""" +Export `model.state_dict()` + +- For non-streaming model: + +./zipformer_prompt_asr/export_PromptASR.py \ + --exp-dir ./zipformer_prompt_asr/exp \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --epoch 50 \ + --avg 10 + +- For streaming model: + +./zipformer_prompt_asr/export_PromptASR.py \ + --exp-dir ./zipformer_prompt_asr/exp \ + --causal 1 \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --epoch 50 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from torch import Tensor, nn +from train_bert_encoder import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + assert params.jit is False, "Jit is not 
supported yet" + + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py b/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py new file mode 100644 index 000000000..59f822748 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/joiner.py @@ -0,0 +1,86 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from scaling import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + context_dim: int = 512, + context_injection: bool = False, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + if context_injection: + self.context_proj = ScaledLinear( + context_dim, joiner_dim, initial_scale=0.25 + ) + else: + self.context_proj = None + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + context: torch.Tensor = None, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + context: + An embedding vector representing the previous context information + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). 
+ """ + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] + + if project_input: + if context: + logit = ( + self.encoder_proj(encoder_out) + + self.decoder_proj(decoder_out) + + self.context_proj(context) + ) + else: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + if context is not None: + logit = encoder_out + decoder_out + context.unsqueeze(1).unsqueeze(1) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py new file mode 100644 index 000000000..9a693ca4f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/ls_text_normalization.py @@ -0,0 +1,153 @@ +import re + +words = { + 0: "zero", + 1: "one", + 2: "two", + 3: "three", + 4: "four", + 5: "five", + 6: "six", + 7: "seven", + 8: "eight", + 9: "nine", + 10: "ten", + 11: "eleven", + 12: "twelve", + 13: "thirteen", + 14: "fourteen", + 15: "fifteen", + 16: "sixteen", + 17: "seventeen", + 18: "eighteen", + 19: "nineteen", + 20: "twenty", + 30: "thirty", + 40: "forty", + 50: "fifty", + 60: "sixty", + 70: "seventy", + 80: "eighty", + 90: "ninety", +} +ordinal_nums = [ + "zeroth", + "first", + "second", + "third", + "fourth", + "fifth", + "sixth", + "seventh", + "eighth", + "ninth", + "tenth", + "eleventh", + "twelfth", + "thirteenth", + "fourteenth", + "fifteenth", + "sixteenth", + "seventeenth", + "eighteenth", + "nineteenth", + "twentieth", +] + +num_ordinal_dict = {num: ordinal_nums[num] for num in range(21)} + + +def year_to_words(num: int): + assert isinstance(num, int), num + # check if a num is representing a year + if num > 1500 and num < 2000: + return words[num // 100] + " " + num_to_words(num % 100) + elif num == 2000: + return "TWO THOUSAND" + elif num > 2000: + return "TWO THOUSAND AND " + num_to_words(num % 100) + else: + return num_to_words(num) + + +def num_to_words(num: int): + # Return the English words of a integer number + + # If this is a year number + if num > 1500 and num < 2030: + return year_to_words(num) + + if num < 20: + return words[num] + if num < 100: + if num % 10 == 0: + return words[num // 10 * 10] + else: + return words[num // 10 * 10] + " " + words[num % 10] + if num < 1000: + return words[num // 100] + " hundred and " + num_to_words(num % 100) + if num < 1000000: + return num_to_words(num // 1000) + " thousand " + num_to_words(num % 1000) + return num + + +def num_to_ordinal_word(num: int): + + return num_ordinal_dict.get(num, num_to_words(num)).upper() + + +def replace_full_width_symbol(s: str) -> str: + # replace full-width symbol with theri half width counterpart + s = s.replace("“", '"') + s = s.replace("”", '"') + s = s.replace("‘", "'") + s = s.replace("’", "'") + + return s + + +def decoding_normalization(text: str) -> str: + text = replace_full_width_symbol(text) + + # Only keep all alpha-numeric characters, hypen and apostrophe + text = text.replace("-", " ") + text = re.sub(r"[^a-zA-Z0-9\s']+", "", text) + return text + + +def word_normalization(word: str) -> str: + # 1 .Use full word for some abbreviation + # 2. Convert digits to english words + # 3. 
Convert ordinal number to english words + if word == "MRS": + return "MISSUS" + if word == "MR": + return "MISTER" + if word == "ST": + return "SAINT" + if word == "ECT": + return "ET CETERA" + if word.isnumeric(): + word = num_to_words(int(word)) + return str(word).upper() + # e.g 9TH, 6TH + if word[-2:] == "TH" and word[0].isnumeric(): + return num_to_ordinal_word(int(word[:-2])).upper() + if word[0] == "'": + return word[1:] + + return word + + +def simple_normalization(text: str) -> str: + text = replace_full_width_symbol(text) + text = text.replace("--", " ") + + return text + + +if __name__ == "__main__": + + s = str(1830) + out = word_normalization(s) + print(s, out) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py new file mode 100644 index 000000000..77b4057c4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py @@ -0,0 +1,262 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +import warnings +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, penalize_abs_values_gt +from torch import Tensor + +from icefall.utils import add_sos, make_pad_mask + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. 
+ """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, + vocab_size, + initial_scale=0.25, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, + vocab_size, + initial_scale=0.25, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + text: + A 2-D tensor of integer dtype containing prompt text, of shape (N, T). + It is exptected to contain the style prompt (first) and then the content + prompt. + text_lens: + A 1-D tensor of shape (N,). It contains the number of elements (bytes) + in `text` before padding, which will include the lengths of the + style plus the content prompt. + style_lens: + A 1-D tensor of shape (N,), containing the number of elements (bytes) + within each row of `text` that correspond to the style prompt (these + are expected to come first). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + x, x_lens = self.encoder_embed(x, x_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, x_lens = self.encoder( + x, + x_lens, + src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) + + def encode_audio( + self, + feature: Tensor, + feature_lens: Tensor, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Encode the input audio features + + Args: + feature (Tensor): Input audio (N,T,C) + feature_lens (Tensor): Length of input audio (N,) + Returns: + Tuple[Tensor, Tensor]: Encoded acoustic features and length + """ + x, x_lens = self.encoder_embed(feature, feature_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder( + x=x, + x_lens=x_lens, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py new file mode 100644 index 000000000..21c7b4fac --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py @@ -0,0 +1,392 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import random +import warnings +from typing import Dict, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear, penalize_abs_values_gt +from torch import Tensor + +from icefall.utils import add_sos, make_pad_mask + + +class PromptedTransducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + text_encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + use_BERT: bool = True, + text_encoder_type: str = "BERT", + text_encoder_adapter: bool = False, + freeze_text_encoder: bool = True, + context_fuser: nn.Module = None, + ): + """ + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + text_encoder: + This is a encoder that processes text information (e.g content prompt + and style prompt). The input is `x` of (N,T) and `x_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + text_encoder_type: + The type of the text_encoder. Supported are (BERT, DistilBERT) + context_fuser + A optional module that fuses the embeddings of text encoder. The fused embedding + will be added to the joiner. 
+ """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.text_encoder = text_encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, + vocab_size, + initial_scale=0.25, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, + vocab_size, + initial_scale=0.25, + ) + + self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT + self.context_fuser = context_fuser + + assert text_encoder_type in ( + "BERT", + "DistilBERT", + "BERT-UNCASED", + ), f"Unseen text_encoder type {text_encoder_type}" + self.text_encoder_dim = ( + self.text_encoder.config.hidden_size + if text_encoder_type in ("BERT", "BERT-UNCASED") + else self.text_encoder.config.dim + ) + self.freeze_text_encoder = freeze_text_encoder + + if text_encoder_adapter: + self.text_encoder_adapter = nn.Sequential( + nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False), + nn.Tanh(), + ) + else: + self.text_encoder_adapter = None + + self.style_prompt_embedding = nn.Parameter( + torch.full((self.text_encoder_dim,), 0.5) + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + encoded_inputs: Dict, + style_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + use_pre_text: bool = True, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + text: + A 2-D tensor of integer dtype containing prompt text, of shape (N, T). + It is exptected to contain the style prompt (first) and then the content + prompt. + text_lens: + A 1-D tensor of shape (N,). It contains the number of elements (bytes) + in `text` before padding, which will include the lengths of the + style plus the content prompt. + style_lens: + A 1-D tensor of shape (N,), containing the number of elements (bytes) + within each row of `text` that correspond to the style prompt (these + are expected to come first). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + if self.freeze_text_encoder: + self.text_encoder.eval() + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + x, x_lens = self.encoder_embed(x, x_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # freeze the BERT text encoder + + if use_pre_text: + memory, memory_key_padding_mask = self.encode_text( + encoded_inputs, style_lens=style_lens + ) + else: + memory = None + memory_key_padding_mask = None + + encoder_out, x_lens = self.encoder( + x, + x_lens, + src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + if self.context_fuser is not None and memory is not None: + memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C) + context = self.context_fuser(memory, padding_mask=memory_key_padding_mask) + context = self.joiner.context_proj(context) + else: + context = None + + logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) + + def _add_style_indicator(self, memory: Tensor, style_lens: Tensor): + """ + Adds to `memory` an indicator that is 1.0 for positions that correspond to + the `style prompt` and 0 elsewhere. 
The scale can be fixed because the + scale of the embedding vector can adjust to compensate. + + Args: + memory: (memory_len, batch_size, embed_dim) + style_lens: (batch_size,), a vector of lengths of the style prompt. + """ + + (memory_len, batch_size, embed_dim) = memory.shape + + indicator = ( + torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens + ) + indicator = indicator.to(memory.dtype) + + extra_term = torch.zeros_like(memory) + extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand( + memory_len, batch_size, self.text_encoder_dim + ) + + return memory + extra_term + + def encode_text( + self, + encoded_inputs: Dict, + style_lens: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Get the embeddings of text + + Args: + encoded_inputs: The encoded inputs generated by a tokenizer (Dict) + + Returns: + Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the + text_encoder and the attention mask + """ + text_lens = encoded_inputs.pop("length") # need to use pop to remove this item + + # Freeze the pre-trained text encoder + with torch.no_grad(): + memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C) + memory = memory.permute(1, 0, 2) + + # Text encoder adapter + if self.text_encoder_adapter is not None: + memory = self.text_encoder_adapter(memory) + + memory = self._add_style_indicator(memory, style_lens) + + memory_key_padding_mask = make_pad_mask(text_lens) + + return memory, memory_key_padding_mask + + def encode_audio( + self, + feature: Tensor, + feature_lens: Tensor, + memory: Optional[Tensor], + memory_key_padding_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor]: + """Encode the input audio features + + Args: + feature (Tensor): Input audio (N,T,C) + feature_lens (Tensor): Length of input audio (N,) + memory (Tensor): Embeddings from the text encoder + memory_key_padding_mask (Tensor): _description_ + + Returns: + Tuple[Tensor, Tensor]: _description_ + """ + x, x_lens = self.encoder_embed(feature, feature_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder( + x=x, + x_lens=x_lens, + src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +Transducer = PromptedTransducer # for decoding diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py new file mode 100644 index 000000000..a767761eb --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py @@ -0,0 +1,1168 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
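+
+# A minimal usage sketch of how the ScaledAdam optimizer and the Eden scheduler
+# defined below are typically wired together in a training loop. `model`,
+# `train_dl`, `compute_loss` and `num_epochs` are placeholders, and the
+# hyper-parameter values simply follow the suggestions in the docstrings below:
+#
+#     optimizer = ScaledAdam(
+#         model.named_parameters(), lr=0.04, clipping_scale=2.0
+#     )
+#     scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)
+#     for epoch in range(num_epochs):
+#         scheduler.step_epoch(epoch)
+#         for batch in train_dl:
+#             loss = compute_loss(model, batch)
+#             loss.backward()
+#             optimizer.step()
+#             optimizer.zero_grad()
+#             scheduler.step_batch()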
+ +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! 
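+        # When the `with self.batched_params(...)` block exits, control returns
+        # here and the loop below copies the (now updated) stacked values back
+        # into the individual real parameters.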
+
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+            for i, p in enumerate(batch):  # batch is list of Parameter
+                p.copy_(stacked_params[i])
+
+
+class ScaledAdam(BatchedOptimizer):
+    """
+    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+    proportionally to the norm of that parameter; and also learn the scale of the parameter,
+    in log space, subject to upper and lower limits (as if we had factored each parameter as
+    param = underlying_param * log_scale.exp()).
+
+
+    Args:
+      params: The parameters or param_groups to optimize (like other Optimizer subclasses).
+          Unlike common optimizers, which accept model.parameters() or groups of parameters(),
+          this optimizer can also accept model.named_parameters() or groups of named_parameters().
+          See the comments of function _get_names_of_parameters for its 4 possible cases.
+      lr: The learning rate. We will typically use a learning rate schedule that starts
+          at 0.03 and decreases over time, i.e. much higher than other common
+          optimizers.
+      clipping_scale: (e.g. 2.0)
+          A scale for gradient-clipping: if specified, the normalized gradients
+          over the whole model will be clipped to have 2-norm equal to
+          `clipping_scale` times the median 2-norm over the most recent period
+          of `clipping_update_period` minibatches. By "normalized gradients",
+          we mean after multiplying by the rms parameter value for this tensor
+          [for non-scalars]; this is appropriate because our update is scaled
+          by this quantity.
+      betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+          Must satisfy 0 < beta1 <= beta2 < 1.
+      scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+          scale of each parameter tensor and scalar parameters of the model.
+          If each parameter were decomposed
+          as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+          would be the scaling factor on the learning rate of p_scale.
+      eps: A general-purpose epsilon to prevent division by zero.
+      param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+          learning the scale on the parameters (we'll constrain the rms of each non-scalar
+          parameter tensor to be >= this value).
+      param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+          learning the scale on the parameters (we'll constrain the rms of each non-scalar
+          parameter tensor to be <= this value).
+      scalar_max: Maximum absolute value for scalar parameters (applicable if your
+          model has any parameters with numel() == 1).
+      size_update_period: The periodicity, in steps, with which we update the size (scale)
+          of the parameter tensor. This is provided to save a little time
+          in the update.
+ clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. 
+ # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + assert "named_params" in cur_group + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. 
+ p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + numel = p.numel() + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. 
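+            # The new clipping threshold is `clipping_scale` times the median of
+            # the gradient norms collected over the last `clipping_update_period`
+            # minibatches; the percentage of recently clipped minibatches is
+            # logged as well.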
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummy values used by following `zip` statement. 
+ batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. 
+ p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. 
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[Union[int, float]] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. 
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py new file mode 100644 index 000000000..48fd2612a --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script loads a checkpoint (`pretrained.pt`) and uses it to decode waves. +You can generate the checkpoint with the following command: + +./zipformer/export_PromptASR.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --epoch 50 \ + --avg 10 + +Utterance level context biasing: + +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --method modified_beam_search \ + --use-pre-text True \ + --content-prompt "bessy random words hello k2 ASR" \ + --use-style-prompt True \ + librispeech.flac + + +Word level context biasing: + +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500_fallback_coverage_0.99/tokens.txt \ + --method modified_beam_search \ + --use-pre-text True \ + --content-prompt "The topic is about horses." \ + --use-style-prompt True \ + test.wav + + +""" + +import argparse +import logging +import math +import warnings +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import greedy_search_batch, modified_beam_search +from text_normalization import _apply_style_transform, train_text_normalization +from torch.nn.utils.rnn import pad_sequence +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) + +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500_fallback_coverage_0.99/bpe.model", + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. 
Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=True, + help="Use content prompt during decoding", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt during decoding", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--content-prompt", type=str, default="", help="The content prompt for decoding" + ) + + parser.add_argument( + "--style-prompt", + type=str, + default="Mixed-cased English text with punctuations, feel free to change it.", + help="The style prompt for decoding", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
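+        # For streaming (causal) models these options must each be a single value
+        # at decoding time, e.g. something like `--chunk-size 32 --left-context-frames 256`
+        # (see add_model_arguments for the exact flags); comma-separated lists of
+        # values are typically only used during training.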
+ + logging.info("Creating model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) # for text encoder + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + assert ( + len(params.sound_files) == 1 + ), "Only support decoding one audio at this moment" + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # encode prompts + if params.use_pre_text: + pre_text = [train_text_normalization(params.content_prompt)] + pre_text = _apply_style_transform(pre_text, params.pre_text_transform) + else: + pre_text = [""] + + if params.use_style_prompt: + style_text = [params.style_prompt] + style_text = _apply_style_transform(style_text, params.style_text_transform) + else: + style_text = [""] + + if params.use_pre_text or params.use_style_prompt: + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_text, + style_texts=style_text, + tokenizer=tokenizer, + device=device, + no_limit=True, + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=features, + feature_lens=feature_lengths, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + hyps.append(sp.decode(hyp_tokens)[0]) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + hyps.append(sp.decode(hyp_tokens)[0]) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py new file mode 100644 index 000000000..0e6764ba0 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py @@ -0,0 +1,1872 @@ +# Copyright 2022 Xiaomi Corp. 
(authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import logging +import math +import random +from functools import reduce +from itertools import repeat +from typing import Optional, Tuple, Union + +import k2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn import Embedding as ScaledEmbedding + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1 + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + for (x, y) in self.pairs: + assert isinstance(x, float) or isinstance(x, int) + assert isinstance(y, float) or isinstance(y, int) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise lienar + functions to self and p, but with the same x values. 
+ + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p crosss. + """ + assert isinstance(p, PiecewiseLinear) + + # get sorted x-values without repetition. + x_vals = sorted(set([x for x, y in self.pairs] + [x for x, y in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specifiy the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or in training or mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) + + def __float__(self): + batch_count = self.batch_count + if batch_count is None or not self.training or torch.jit.is_scripting(): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. 
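+
+ Sketch of the intended behaviour (the values are arbitrary)::
+
+     x = torch.full((10000,), 1.0e-06)   # well below the default min_abs
+     y = random_cast_to_half(x)
+     # most entries of y are 0.0 and the rest are min_abs (5.0e-06), chosen
+     # with a probability such that y.float().mean() is close to x.mean()
+     # in expectation.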
+ """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. + """ + + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = x > self.cutoff + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1 - q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. 
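+ # For instance, if removing the component of x along `new_direction`
+ # removed half of its variance, variance_proportion would be 0.5; the
+ # 1.0e-20 term only guards against division by zero.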
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return (x - bias) * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). + @staticmethod + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. 
+ log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() + return x * scales + + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) + + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. 
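+
+ A rough usage sketch (shapes and scale are illustrative only)::
+
+     conv = ScaledConv2d(in_channels=1, out_channels=8, kernel_size=3,
+                         stride=2, initial_scale=0.25)
+     y = conv(torch.randn(4, 1, 32, 32))
+     # behaves exactly like the corresponding nn.Conv2d, except that the
+     # initial weight magnitudes (and the bias init range) are scaled by 0.25.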
+ """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) + + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. + self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: + """ + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. 
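+ # As a concrete illustration (numbers chosen arbitrarily): with seq_len=25
+ # and chunk_size=10, right_pad below is 5, so the chunkwise branch sees
+ # 3 chunks of 10 frames each, and the output is cropped back to 25 frames.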
+ left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. 
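+ # r_loss below is |log(clamped_rms / rms)|: it is zero whenever
+ # min_rms <= rms <= max_rms and grows logarithmically outside that range.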
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + except Exception as e: + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. 
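+ # A typical usage sketch (the numbers are illustrative, not prescriptive):
+ #     self.balancer = Balancer(d_model, channel_dim=-1,
+ #                              min_positive=0.05, max_positive=0.95,
+ #                              min_abs=0.2, max_abs=4.0)
+ # applied as x = self.balancer(x) in forward(); it is an identity in the
+ # forward pass and only nudges gradients in the backward pass.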
+ self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. + return 0.8139535143 * _atanh(x) + + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). 
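+ # Worked example (scalar, arbitrary values): for x = 12.0 with limit = 10.0,
+ # aux_loss = penalty * 12.0, whose gradient w.r.t. x is penalty -- the same
+ # gradient as the "real" loss penalty * (12.0 - 10.0).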
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. 
limit={float(w.whitening_limit)}" + ) + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. 
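+ # For example (a sketch; the name is arbitrary):
+ #     x = with_loss(x, aux_penalty, "attn_scores")
+ # leaves x unchanged in the forward pass, but in backward() every element of
+ # aux_penalty receives a gradient of 1, as if aux_penalty.sum() had been
+ # added to the training loss.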
+ return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. 
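+ # For example, a derivative value of 0.578 is stored as
+ # (0.578 - floor) * 255 / (ceil - floor) + U(0, 1) ~= 127.5 + noise in a
+ # single byte; dequantizing in backward() recovers it to within one
+ # quantization step (about 0.005), and approximately without bias.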
+ floor = -0.044 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. + @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + (ans,) = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. 
+ assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + if torch.jit.is_scripting(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + return k2.swoosh_l_forward(x) + else: + return k2.swoosh_l(x) + # return SwooshLFunction.apply(x) + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. + + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + if torch.jit.is_scripting(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + if not x.requires_grad: + return k2.swoosh_r_forward(x) + else: + return k2.swoosh_r(x) + # return SwooshRFunction.apply(x) + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. 
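+# As a quick numeric reference, SwooshRForward(torch.tensor(1.0)) is roughly
+# 0.30: log(1 + exp(0)) - 0.08 * 1.0 - 0.313261687 ~= 0.6931 - 0.3933 ~= 0.2999.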
+def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. + activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). 
If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + layer = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = layer.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter("bias", layer.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, x: Tensor): + if torch.jit.is_scripting(): + if self.activation == "SwooshL": + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationDropoutAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) + + +class ClipGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, limit: float): + ctx.limit = limit + return x + + @staticmethod + def backward(ctx, x_grad, *args): + return x_grad.clamp(-ctx.limit, ctx.limit), None + + +def clip_grad(x: Tensor, limit: float): + return ClipGradFunction.apply(x, limit) + + +class AbsValuePenalizer(nn.Module): + """ + This module adds a penalty to the loss function when ever the absolute value of + any element of the input tensor exceeds a certain limit. + """ + + def __init__(self, limit: float, prob: float = 0.1, penalty: float = 1.0e-04): + super().__init__() + self.limit = limit + self.penalty = penalty + + self.prob = prob + self.name = None # will be set in training loop + + # 20% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.2) + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or not self.training + or random.random() > self.prob + ): + # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + return _no_op(x) # the _no_op op is to make our diagnostics code work. 
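+ # Otherwise attach the penalty: elements with |x| > self.limit contribute an
+ # extra term to the loss via penalize_abs_values_gt() (on roughly a fraction
+ # `prob` of training batches), while the forward output itself is unchanged.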
+ + x = penalize_abs_values_gt( + x, limit=self.limit, penalty=self.penalty, name=self.name + ) + return x + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. 
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear((0, 10.0)) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. 
+ assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_swooshl_deriv() + _test_swooshr_deriv() + _test_activation_dropout_and_linear() + _test_double_swish_deriv() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py new file mode 100644 index 000000000..7acbc1808 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/subsampling.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Tuple + +import torch +from scaling import ( + Balancer, + BiasNorm, + Dropout3, + FloatLike, + Optional, + ScaledConv2d, + ScaleGrad, + ScheduledFloat, + SwooshL, + SwooshR, + Whiten, +) +from torch import Tensor, nn + + +class ConvNeXt(nn.Module): + """ + Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + """ + + def __init__( + self, + channels: int, + hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), + layerdrop_rate: FloatLike = None, + ): + super().__init__() + padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + hidden_channels = channels * hidden_ratio + if layerdrop_rate is None: + layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) + self.layerdrop_rate = layerdrop_rate + + self.depthwise_conv = nn.Conv2d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=padding, + ) + + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1 + ) + + self.hidden_balancer = Balancer( + hidden_channels, + channel_dim=1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + self.activation = SwooshL() + self.pointwise_conv2 = ScaledConv2d( + in_channels=hidden_channels, + out_channels=channels, + kernel_size=1, + initial_scale=0.01, + ) + + self.out_balancer = Balancer( + channels, + channel_dim=1, + min_positive=0.4, + max_positive=0.6, + min_abs=1.0, + max_abs=6.0, + ) + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or not self.training: + return self.forward_internal(x) + layerdrop_rate = float(self.layerdrop_rate) + + if layerdrop_rate != 0.0: + batch_size = x.shape[0] + mask = ( + torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) + > layerdrop_rate + ) + else: + mask = None + # turns out this caching idea does not work with --world-size > 1 + # return caching_eval(self.forward_internal, x, mask) + return self.forward_internal(x, mask) + + def 
forward_internal( + self, x: Tensor, layer_skip_mask: Optional[Tensor] = None + ) -> Tensor: + """ + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + + The returned value has the same shape as x. + """ + bypass = x + x = self.depthwise_conv(x) + x = self.pointwise_conv1(x) + x = self.hidden_balancer(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + if layer_skip_mask is not None: + x = x * layer_skip_mask + + x = bypass + x + x = self.out_balancer(x) + x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last + x = self.out_whiten(x) + x = x.transpose(1, 3) # (N, C, H, W) + + return x + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: FloatLike = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-3)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite + """ + assert in_channels >= 7 + super().__init__() + + # The ScaleGrad module is there to prevent the gradients + # w.r.t. the weight or bias of the first Conv2d module in self.conv from + # exceeding the range of fp16 when using automatic mixed precision (amp) + # training. (The second one is necessary to stop its bias from getting + # a too-large gradient). + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ScaleGrad(0.2), + Balancer(layer1_channels, channel_dim=1, max_abs=1.0), + SwooshR(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + Balancer(layer2_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + Balancer(layer3_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + ) + + # just one convnext layer + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + + out_width = (((in_channels - 1) // 2) - 1) // 2 + + self.out = nn.Linear(out_width * layer3_channels, out_channels) + # use a larger than normal grad_scale on this whitening module; there is + # only one such module, so there is not a concern about adding together + # many copies of this extra gradient term. + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), + prob=(0.025, 0.25), + grad_scale=0.02, + ) + + # max_log_eps=0.0 is to prevent both eps and the output of self.out from + # getting large, there is an unnecessary degree of freedom. + self.out_norm = BiasNorm(out_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+
+        Returns:
+          - a tensor of shape (N, (T-7)//2, odim)
+          - output lengths, of shape (batch_size,)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
+        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
+        # gradients.
+        x = self.conv(x)
+        x = self.convnext(x)
+
+        # Now x is of shape (N, layer3_channels, (T-7)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+
+        x = x.transpose(1, 2).reshape(b, t, c * f)
+        # now x: (N, (T-7)//2, out_width * layer3_channels)
+
+        x = self.out(x)
+        # Now x is of shape (N, (T-7)//2, odim)
+        x = self.out_whiten(x)
+        x = self.out_norm(x)
+        x = self.dropout(x)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            x_lens = (x_lens - 7) // 2
+        assert x.size(1) == x_lens.max().item()
+
+        return x, x_lens
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py b/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
new file mode 100755
index 000000000..13483637d
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/test_model.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
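
For reference, the frame-count arithmetic stated in the Conv2dSubsampling docstring of
subsampling.py above (T' = (T-3)//2 - 2 == (T-7)//2, matched by x_lens = (x_lens - 7) // 2
in forward()) can be traced convolution by convolution. The helper below is only an
illustrative sketch and is not part of the recipe:

def _subsampled_length(T: int) -> int:
    T = T - 2             # Conv2d(kernel_size=3, stride=1): no padding in time
    T = (T - 3) // 2 + 1  # Conv2d(kernel_size=3, stride=2): no padding
    T = T - 2             # Conv2d(kernel_size=3, stride=(1, 2)): stride 1 in time
    return T


for T_in in (15, 16, 100, 101):
    assert _subsampled_length(T_in) == (T_in - 7) // 2
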
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless4/test_model.py +""" + +from scaling import ScheduledFloat +from train_subformer import get_params, get_text_encoder, get_transducer_model +from zipformer import Zipformer2 + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 24 + params.dim_feedforward = 1536 # 384 * 4 + params.encoder_dim = 384 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf +def test_model_M(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,15,15" + + params.text_encoder_dim = (192, 192, 256, 384) + params.decoder_dim = 512 + params.joiner_dim = 512 + model = Zipformer2( + output_downsampling_factor=8, + downsampling_factor=(1, 2, 4, 8), + num_encoder_layers=(2, 4, 4, 4), + encoder_dim=(192, 192, 256, 384), + encoder_unmasked_dim=(192, 192, 256, 256), + query_head_dim=(32, 32, 32, 32), + pos_head_dim=(4, 4, 4, 4), + value_head_dim=(12, 12, 12, 12), + pos_dim=48, + num_heads=(4, 4, 4, 8), + feedforward_dim=( + 384, + 512, + 768, + 1024, + ), # could increase this if there is nough data + cnn_module_kernel=(31, 31, 15, 15), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=False, + ) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + model = Zipformer2( + output_downsampling_factor=8, + downsampling_factor=(1, 2, 4, 8), + num_encoder_layers=(2, 4, 6, 6), + encoder_dim=(256, 256, 384, 512), + encoder_unmasked_dim=(196, 196, 256, 256), + query_head_dim=(32, 32, 32, 32), + pos_head_dim=(4, 4, 4, 4), + value_head_dim=(12, 12, 12, 12), + pos_dim=48, + num_heads=(4, 4, 4, 8), + feedforward_dim=( + 384, + 512, + 768, + 1024, + ), # could increase this if there is nough data + cnn_module_kernel=(31, 31, 15, 15), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=False, + ) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + # test_model_1() + test_model_M() + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py new file mode 100644 index 000000000..efb4acc3c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py @@ -0,0 +1,101 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List + + +def train_text_normalization(s: str) -> str: + # replace full-width with half-width + s = s.replace("“", '"') + s = s.replace("”", '"') + s = s.replace("‘", "'") + s = s.replace("’", "'") + if s[:2] == '" ': # remove the starting double quote + s = s[2:] + + return s + + +def ref_text_normalization(ref_text: str) -> str: + # Rule 1: Remove the [FN#[]] + p = r"[FN#[0-9]*]" + pattern = re.compile(p) + + res = pattern.findall(ref_text) + ref_text = re.sub(p, "", ref_text) + + ref_text = train_text_normalization(ref_text) + + return ref_text + + +def remove_non_alphabetic(text: str, strict: bool = True) -> str: + # Recommend to set strict to False + if not strict: + # Note, this also keeps space, single quote(') and hypen (-) + text = text.replace("-", " ") + text = text.replace("—", " ") + return re.sub(r"[^a-zA-Z0-9\s']+", "", text) + else: + # only keeps space + return re.sub(r"[^a-zA-Z\s]+", "", text) + + +def upper_only_alpha(text: str) -> str: + return remove_non_alphabetic(text.upper(), strict=False) + + +def lower_only_alpha(text: str) -> str: + return remove_non_alphabetic(text.lower(), strict=False) + + +def lower_all_char(text: str) -> str: + return text.lower() + + +def upper_all_char(text: str) -> str: + return text.upper() + + +def _apply_style_transform(text: List[str], transform: str) -> List[str]: + """Apply transform to a list of text. By default, the text are in + ground truth format, i.e mixed-punc. + + Args: + text (List[str]): Input text string + transform (str): Transform to be applied + + Returns: + List[str]: _description_ + """ + if transform == "mixed-punc": + return text + elif transform == "upper-no-punc": + return [upper_only_alpha(s) for s in text] + elif transform == "lower-no-punc": + return [lower_only_alpha(s) for s in text] + elif transform == "lower-punc": + return [lower_all_char(s) for s in text] + else: + raise NotImplementedError(f"Unseen transform: {transform}") + + +if __name__ == "__main__": + ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." + print(ref_text) + res = upper_only_alpha(ref_text) + print(res) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py new file mode 100644 index 000000000..7075c9154 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -0,0 +1,1390 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + + +# For mix precision training: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# To train a streaming model + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --causal 1 + --exp-dir zipformer/exp \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_baseline import Transducer +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from text_normalization import train_text_normalization, upper_only_alpha +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_first( + texts: List[str], + pre_texts: List[str], + context_list: Optional[str] = None, + rare_word_list: Optional[List[str]] = None, +) -> str: + # Always get the first one, which is the mixed-cased text with punc + out = {"text": texts[0], "pre_text": pre_texts[0]} + return out + + +def get_upper_only_alpha( + texts: List[str], + pre_texts: List[str], + context_list: Optional[str] = None, + rare_word_list: Optional[List[str]] = None, +) -> str: + # Always get the first one, which is the mixed-cased text with punc, + # but with upper case it and remove punctuation + out = { + "text": upper_only_alpha(texts[0]), + "pre_text": upper_only_alpha(pre_texts[0]), + } + return out + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
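+    # A rough worked example with illustrative numbers: with --max-duration 1000,
+    # --world-size 4 and --ref-duration 600, after 3000 actual batches this
+    # returns 3000 * (1000 * 4) / 600 = 20000, roughly the batch count at which
+    # schedules such as ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) are expected
+    # to reach their final value.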
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-dim", + type=str, + default="256,256,384,512", + help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements, " + "or you should change other configs in the code.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from 
file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + texts = [train_text_normalization(t) for t in texts] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + if random.random() < 0.02: + logging.info(f"Ref texts: {texts[0]}") + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. 
+ scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 30.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].texts[0], out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].texts[0]}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + text_sampling_func = get_upper_only_alpha + logging.info(f"Text sampling func: {text_sampling_func}") + train_dl = libriheavy.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + text_sampling_func=text_sampling_func, + ) + + valid_cuts = libriheavy.dev_cuts() + valid_dl = libriheavy.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py new file mode 100755 index 000000000..e253d1118 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -0,0 +1,1798 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
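
The --average-period option of both training scripts (train_baseline.py above and
train_bert_encoder.py below) maintains model_avg with the update quoted in its help
string. A minimal numeric sketch of that update, using a scalar in place of a
parameter tensor (illustrative only, not part of either script):

def update_average(avg: float, cur: float, batch_idx_train: int, average_period: int) -> float:
    # model_avg = model * (average_period / batch_idx_train)
    #             + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
    w = average_period / batch_idx_train
    return cur * w + avg * (1.0 - w)


# Sampling the "model" every average_period batches keeps avg equal to the running mean.
avg = 0.0
for step, cur in enumerate([1.0, 2.0, 3.0], start=1):
    avg = update_average(avg, cur, batch_idx_train=step * 200, average_period=200)
assert abs(avg - 2.0) < 1e-9  # mean of 1.0, 2.0 and 3.0
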
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For mix precision training: + +(1) Non-streaming model, **without** context list + +./zipformer_prompt_asr/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --subset medium \ + --causal False \ + --exp-dir zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --memory-layer 0 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --use-style-prompt True \ + --use-context-list False + +(2) Non-streaming model, **with** context list + +./zipformer_prompt_asr/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --subset medium \ + --causal False \ + --exp-dir zipformer_prompt_asr/exp \ + --max-duration 1000 \ + --memory-layer 0 \ + --text-encoder-type BERT \ + --text-encoder-dim 768 \ + --use-style-prompt True \ + --use-context-list True \ + --top-k 10000 \ + --rare-word-file data/context_biasing/small_rare_words_topk_10000.txt + + +""" + + +import argparse +import copy +import logging +import os +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import numpy +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from dataset import ( + naive_triplet_text_sampling, + random_shuffle_subset, + triplet_text_sampling, + triplet_text_sampling_with_context_list, +) +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model_with_BERT import PromptedTransducer +from optim import Eden, ScaledAdam +from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR +from subsampling import Conv2dSubsampling +from text_normalization import ( + lower_all_char, + lower_only_alpha, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + +style_transforms = [ + lambda x: x, # return it self + upper_only_alpha, + lower_only_alpha, + lower_all_char, +] + + +def get_first(texts: List[str], pre_texts: List[str]) -> str: + out = { + "text": texts[0], + "pre_text": pre_texts[0], + "style_text": "", + "transform_ids": 0, + } + return out + + +def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str: + # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha + out = { + "text": upper_only_alpha(texts[0]), + "pre_text": upper_only_alpha(pre_texts[0]), + "style_text": "", + "transform_ids": 0, + } + return out + + +def 
get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--memory-dropout-rate", + type=float, + default=0.05, + help="By which probability, dropout the memory when doing cross-attention.", + ) + + parser.add_argument( + "--memory-layer", + type=int, + default=0, + help="Start doing cross-attention from which layer. Zero-indexed", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--freeze-text-encoder", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--text-encoder-type", + type=str, + default="BERT", + choices=["BERT", "DistilBERT"], + help="Type of the text encoder", + ) + + parser.add_argument( + "--text-encoder-dim", + type=int, + default=768, + help="Dimension of the text encoder", + ) + + parser.add_argument( + "--text-encoder-adapter", + type=str2bool, + default=False, + help="An adapter for pre-trained BERT", + ) + + parser.add_argument( + "--context-injection", + type=str2bool, + default=False, + help="Inject context embedding into the joiner", + ) + + parser.add_argument( + "--context-dropout-rate", + type=float, + default=0.05, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Whether to use style prompt.", + ) + + # arguments for using prompt + parser.add_argument( + "--pre-text-shuffle-prob", + type=float, + default=0.05, + help="The proportion of pre_text to be shuffled with in a batch", + ) + + parser.add_argument( + "--style-text-shuffle-prob", + type=float, + default=0.2, + help="The proportion of style_text to be shuffled with in a batch", + ) + + parser.add_argument( + "--prompt-mask-prob", + type=float, + default=0.05, + help="The probability of masking prompts", + ) + + parser.add_argument( + "--forced-upper-pre-text", + type=str2bool, + default=False, + help="Forced format of pre-text", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
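+    # Illustrative arithmetic (a sketch, assuming the usual 100 Hz fbank
+    # frames with feature_dim=80): a 10-second utterance has T = 1000 input
+    # frames; after encoder_embed it becomes (1000 - 7) // 2 = 496 frames,
+    # and the final downsampling by 2 inside the encoder reduces this to
+    # about 248 frames (roughly 50 Hz).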
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +class TextEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int = 256, + embedding_dim: int = 256, + kernel_size: int = 3, + layer1_channels: int = 256, + layer2_channels: int = 256, + bias: bool = True, + dropout: float = 0.1, + ): + super().__init__() + self.embed = nn.Embedding( + num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes + embedding_dim=embedding_dim, # + ) + + assert embedding_dim == layer1_channels # for depth wise convolution + self.conv = nn.Sequential( + nn.Conv1d( + embedding_dim, + layer1_channels, # depthwise convolution + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=layer1_channels, + bias=True, + ), + ScaleGrad(0.2), + Balancer(layer1_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), + nn.ReLU(), + nn.Conv1d( + layer1_channels, + layer2_channels, + kernel_size=1, # pointwise convolution + stride=1, + padding=0, + bias=True, + ), + Balancer(layer2_channels, channel_dim=1, min_positive=0.1, max_abs=1.0), + nn.ReLU(), + ) + + self.out_norm = BiasNorm(layer2_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def forward(self, text: torch.Tensor) -> torch.Tensor: + """Forward function of the text embedding + + Args: + text (torch.Tensor): Text in UTF-8 bytes (T,N) + Returns: + The embeddings of text (T,N,C) + """ + text = self.embed(text) # (T,N,C) + + # src = text + text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T) + text = self.conv(text) + text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C) + # src = src + text + + text = self.out_norm(text) + text = self.dropout(text) + + return text + + +def get_text_encoder(params: AttributeDict) -> nn.Module: + # Return a text encoder + if params.text_encoder_type == "BERT": # This is a BERT-base-cased + from transformers import BertModel + + logging.info("Loading pre-trained BERT-base-cased as text encoder") + if os.path.exists("data/models/bert-base-cased"): + model = BertModel.from_pretrained("data/models/bert-base-cased") + else: + model = BertModel.from_pretrained("bert-base-cased") + assert params.text_encoder_dim == 768 + elif params.text_encoder_type == "BERT-large": + from transformers import BertModel + + logging.info("Loading pre-trained BERT-large-uncased as text encoder") + if os.path.exists("data/models/bert-large-uncased"): + model = BertModel.from_pretrained("data/models/bert-large-uncased") + else: + model = BertModel.from_pretrained("bert-large-uncased") + assert params.text_encoder_dim == 1024 + elif params.text_encoder_type == "DistilBERT": + from transformers import DistilBertModel # This is a DistilBERT-base-cased + + logging.info("Loading pre-trained DistilBERT-base-cased as text encoder") + model = DistilBertModel.from_pretrained("distilbert-base-cased") + assert params.text_encoder_dim == 768 + else: + raise ValueError() + + return model + + +def get_tokenizer(params: AttributeDict): + + if params.text_encoder_type == "BERT": + from transformers import BertTokenizer + + # This is a BERT-base-cased + if os.path.exists("data/models/bert-base-cased"): + tokenizer = BertTokenizer.from_pretrained("data/models/bert-base-cased") + else: + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + elif params.text_encoder_type == "BERT-large": + from transformers import BertTokenizer + + # This is a BERT-large-uncased + if 
os.path.exists("data/models/bert-large-uncased"): + tokenizer = BertTokenizer.from_pretrained("data/models/bert-large-uncased") + else: + tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") + elif params.text_encoder_type == "DistilBERT": + from transformers import DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased") + else: + raise ValueError() + + return tokenizer + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + memory_dim=params.text_encoder_dim, # This is fixed as the BERT base model is 768-D + memory_layer=params.memory_layer, + memory_dropout_rate=params.memory_dropout_rate, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + context_dim=4 * 768 + if params.context_injection + else -1, # the output dim of text encoder + context_injection=params.context_injection, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + text_encoder = get_text_encoder(params) # This should be a cased BERT base model + num_param = sum([p.numel() for p in text_encoder.parameters()]) + logging.info(f"Num params in text encoder: {num_param}") + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = PromptedTransducer( + encoder_embed=encoder_embed, + encoder=encoder, + text_encoder=text_encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + text_encoder_type=params.text_encoder_type, + text_encoder_adapter=params.text_encoder_adapter, + freeze_text_encoder=params.freeze_text_encoder, + context_fuser=None, + ) + + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def _encode_texts_as_bytes_with_tokenizer( + pre_texts: List[str], + style_texts: List[str], + tokenizer, + device: torch.device, + max_len: int = 500, + no_limit: bool = False, +) -> Tuple[Dict, Tensor]: + """ + Encode texts as bytes and then integer tensors. + Note that the style text will be added to the beginning of texts. 
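+
+    Illustrative layout (a sketch of what the code below builds): each string
+    handed to the tokenizer looks roughly like
+        "<style_text> [SEP] <tail of pre_text>"
+    where pre_text is right-truncated so that the combined string stays within
+    the character budget (about 1000 characters, or 5000 when no_limit=True),
+    and the tokenizer further truncates to at most max_len (capped at 500)
+    tokens.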
+ """ + batch_size = len(pre_texts) + max_len = min(max_len, 500) + + if no_limit: + allowed_lens = [5000 - len(s) for s in style_texts] + else: + allowed_lens = [1000 - len(s) for s in style_texts] + truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)] + combined_text = [ + style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size) + ] + + encoded_style_texts = tokenizer( + style_texts, + return_tensors="pt", + padding=True, + truncation=True, + return_length=True, + max_length=max_len, + ) + style_lens = encoded_style_texts["length"].to(device) + + # Use tokenizer to prepare input for text encoder + encoded_inputs = tokenizer( + combined_text, + return_tensors="pt", + padding=True, + truncation=True, + return_length=True, + max_length=max_len, + ).to(device) + + return encoded_inputs, style_lens + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + batch_size = feature.size(0) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + pre_texts = batch["supervisions"]["pre_text"] + style_texts = batch["supervisions"][ + "style_text" + ] # the style texts are in gt format + transform_ids = batch["supervisions"]["transform_ids"] + + # This is to replace full-width symbols with half-width symbols + texts = [train_text_normalization(t) for t in texts] + pre_texts = [train_text_normalization(t) for t in pre_texts] + style_texts = [train_text_normalization(t) for t in style_texts] + + y = sp.encode( + texts, out_type=int + ) # sp.encode treats consecutive space as a single space + y = k2.RaggedTensor(y).to(device) + + if params.forced_upper_pre_text: + pre_texts = [upper_only_alpha(p) for p in pre_texts] + + # only shuffle the pre_text and style texts if during training, and use style prompt + if is_training: + # randomly shuffle&mask the pre_text + pre_texts = random_shuffle_subset( + pre_texts, + p=params.pre_text_shuffle_prob, + p_mask=params.prompt_mask_prob, + ) + + if params.use_style_prompt: + if random.random() < 0.5: + # randomly shuffle the style_text + # now the style_texts are all in gt format + style_texts = random_shuffle_subset( + style_texts, + p=params.style_text_shuffle_prob, + p_mask=params.prompt_mask_prob, + ) + + assert len(transform_ids) == len(style_texts) + + for i in range(len(style_texts)): + t = transform_ids[i] # get the transform id + style_texts[i] = style_transforms[t](style_texts[i]) + + if not 
params.use_style_prompt: + style_texts = [ + "" for _ in style_texts + ] # use empty string for style texts if don't use style prompt + + if random.random() < 0.05: + logging.info(f"Pre texts: {pre_texts[0]}") + logging.info(f"Ref texts: {texts[0]}") + logging.info(f"Style texts: {style_texts[0]}") + + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + ) + + if random.random() < 0.02: + logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ") + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + encoded_inputs=encoded_inputs, + style_lens=style_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. 
+ scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
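+            # In short (values taken from the conditions below): every 100
+            # batches the scale is doubled while it is below 8.0, and every
+            # 400 batches while it is below 32.0; a very small scale triggers
+            # a warning (< 0.01) or aborts training (< 1e-5).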
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if not params.use_style_prompt: + if params.pre_text_shuffle_prob == 0.0: + logging.info( + f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}" + ) + logging.info( + "If style prompt is not used, you should be careful when shuffling the pre_text within the same batch" + ) + logging.info("Hard set this probability to 0.0!") + params.pre_text_shuffle_prob = 0.0 + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.freeze_text_encoder: + freeze_modules = ["text_encoder"] + logging.info( + "Freeze the parameters of text encoder and don't include them in the optimizer" + ) + else: + freeze_modules = [] + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True, freeze_modules=freeze_modules + ), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + args.max_duration = 100 + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 30.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].texts[0], out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].texts[0]}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + if params.use_context_list: + text_sampling_func = triplet_text_sampling_with_context_list + else: + text_sampling_func = triplet_text_sampling + + logging.info(f"Text sampling: {text_sampling_func}") + + train_dl = libriheavy.train_dataloaders( + train_cuts, + sampler_state_dict=sampler_state_dict, + text_sampling_func=text_sampling_func, + ) + + # For fair comparison, use fixed sampling in valid dataloaders + valid_cuts = libriheavy.dev_cuts() + valid_dl = libriheavy.valid_dataloaders( + valid_cuts, text_sampling_func=naive_triplet_text_sampling + ) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + tokenizer=tokenizer, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + tokenizer=tokenizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. 
+ sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + tokenizer: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py new file mode 100644 index 000000000..ef0c48e8a --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py @@ -0,0 +1,515 @@ +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +python ./zipformer_prompt_asr/transcribe_bert.py \ + --epoch 50 \ + --avg 10 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/long_audios/long_audio.jsonl.gz \ + --pre-text-transform mixed-punc \ + --style-text-transform mixed-punc \ + --num-history 5 \ + --use-pre-text True \ + --use-gt-pre-text False + + +""" + +import argparse +import logging +import math +import warnings +from pathlib import Path +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from decode_bert import _apply_style_transform +from lhotse import Fbank, load_manifest +from text_normalization import ( + lower_all_char, + lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, + train_text_normalization, + upper_all_char, + upper_only_alpha, +) +from tqdm import tqdm +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, + add_model_arguments, + get_params, + get_tokenizer, + get_transducer_model, +) + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/long_audios/long_audio.jsonl.gz", + help="""This is the manfiest for long audio transcription. 
+ The cust are intended to be sorted, i.e first sort by recording ID and + then sort by start timestamp""", + ) + + parser.add_argument( + "--use-pre-text", + type=str2bool, + default=False, + help="Whether use pre-text when decoding the current chunk", + ) + + parser.add_argument( + "--use-style-prompt", + type=str2bool, + default=True, + help="Use style prompt when evaluation", + ) + + parser.add_argument( + "--pre-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of content prompt, i.e pre_text", + ) + + parser.add_argument( + "--style-text-transform", + type=str, + choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"], + default="mixed-punc", + help="The style of style prompt, i.e style_text", + ) + + parser.add_argument( + "--num-history", + type=int, + default=2, + help="How many previous chunks to look if using pre-text for decoding", + ) + + parser.add_argument( + "--use-gt-pre-text", + type=str2bool, + default=False, + help="Whether use gt pre text when using content prompt", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=True, + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.res_dir = params.exp_dir / "long_audio_transcribe" + params.res_dir.mkdir(exist_ok=True) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "beam_search" in params.method: + params.suffix += f"-{params.method}-beam-size-{params.beam_size}" + + if params.use_pre_text: + if params.use_gt_pre_text: + params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}" + else: + params.suffix += ( + f"-pre-text-{params.pre_text_transform}-history-{params.num_history}" + ) + + book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "") + setup_logger( + f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info" + ) + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + tokenizer = get_tokenizer(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( 
+ filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + # load manifest + manifest = load_manifest(params.manifest_dir) + + results = [] + count = 0 + + last_recording = "" + last_end = -1 + history = [] + num_pre_texts = [] + + for cut in manifest: + if cut.has_features: + feat = cut.load_features() + feat_lens = cut.num_frames + else: + feat = cut.compute_features(extractor=Fbank()) + feat_lens = feat.shape[0] + + cur_recording = cut.recording.id + + if cur_recording != last_recording: + last_recording = cur_recording + history = [] # clean up the history + last_end = -1 + logging.info("Moving on to the next recording") + else: + if cut.start < last_end - 0.2: # overlap with the previous cuts + logging.warning("An overlap exists between current cut and last cut") + logging.warning("Skipping this cut!") + continue + if cut.start > last_end + 10: + logging.warning( + f"Large time gap between the current and previous utterance: {cut.start - last_end}." + ) + + # prepare input + x = torch.tensor(feat, device=device).unsqueeze(0) + x_lens = torch.tensor( + [ + feat_lens, + ], + device=device, + ) + + if params.use_pre_text: + if params.num_history > 0: + pre_texts = history[-params.num_history :] + else: + pre_texts = [] + num_pre_texts.append(len(pre_texts)) + pre_texts = [train_text_normalization(" ".join(pre_texts))] + fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." 
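+            # The fixed sentence above serves purely as the style prompt: it is
+            # mixed-case and punctuated (it may still be re-cased below by
+            # --style-text-transform), and its content is deliberately unrelated
+            # to the audio so that it does not bias recognition towards
+            # particular words.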
+ style_texts = [fixed_sentence] + + pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) + if params.use_style_prompt: + style_texts = _apply_style_transform( + style_texts, params.style_text_transform + ) + + # encode prompts + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer( + pre_texts=pre_texts, + style_texts=style_texts, + tokenizer=tokenizer, + device=device, + no_limit=True, + ) + if params.num_history > 5: + logging.info( + f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} " + ) + + memory, memory_key_padding_mask = model.encode_text( + encoded_inputs=encoded_inputs, + style_lens=style_lens, + ) # (T,B,C) + else: + memory = None + memory_key_padding_mask = None + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + encoder_out, encoder_out_lens = model.encode_audio( + feature=x, + feature_lens=x_lens, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + if params.method == "greedy_search": + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + hyp = sp.decode(hyp_tokens)[0] # in string format + ref_text = ref_text_normalization( + cut.supervisions[0].texts[0] + ) # required to match the training + + # extend the history + if params.use_gt_pre_text: + history.append(ref_text) + else: + history.append(hyp) + last_end = cut.end # update the last end timestamp + + # append the current decoding result + hyp = hyp.split() + ref = ref_text.split() + results.append((cut.id, ref, hyp)) + + count += 1 + if count % 100 == 0: + logging.info(f"Cuts processed until now: {count}/{len(manifest)}") + logging.info( + f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}" + ) + + logging.info(f"A total of {count} cuts") + logging.info( + f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}" + ) + + results = sorted(results) + recog_path = ( + params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + errs_filename = ( + params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"long-audio-{params.method}", + results, + enable_log=True, + compute_CER=False, + ) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + if params.post_normalization: + params.suffix += "-post-normalization" + + new_res = [] + for item in results: + id, ref, hyp = item + hyp = upper_only_alpha(" ".join(hyp)).split() + ref = upper_only_alpha(" ".join(ref)).split() + new_res.append((id, ref, hyp)) + + new_res = sorted(new_res) + recog_path = ( + params.res_dir + / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" + ) + store_transcripts(filename=recog_path, texts=new_res) + logging.info(f"The transcripts are stored in {recog_path}") + + errs_filename = ( + params.res_dir + / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"long-audio-{params.method}", + new_res, + 
enable_log=True, + compute_CER=False, + ) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py new file mode 100644 index 000000000..533982519 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py @@ -0,0 +1,439 @@ +import argparse +import ast +import glob +import logging +import os +from collections import defaultdict +from typing import Dict, Iterable, List, TextIO, Tuple, Union + +import kaldialign +from lhotse import load_manifest, load_manifest_lazy +from lhotse.cut import Cut, CutSet +from text_normalization import remove_non_alphabetic +from tqdm import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/fbank", + help="Where are the manifest stored", + ) + + parser.add_argument( + "--subset", type=str, default="medium", help="Which subset to work with" + ) + + parser.add_argument( + "--top-k", + type=int, + default=10000, + help="How many words to keep", + ) + + return parser + + +def get_facebook_biasing_list( + test_set: str, + num_distractors: int = 100, +) -> Dict: + # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf + assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors + if num_distractors == 0: + if test_set == "test-clean": + biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv" + elif test_set == "test-other": + biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv" + else: + raise ValueError(f"Unseen test set {test_set}") + else: + if test_set == "test-clean": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv" + elif test_set == "test-other": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv" + else: + raise ValueError(f"Unseen test set {test_set}") + + f = open(biasing_file, "r") + data = f.readlines() + f.close() + + output = dict() + for line in data: + id, _, l1, l2 = line.split("\t") + if num_distractors > 0: # use distractors + biasing_list = ast.literal_eval(l2) + else: + biasing_list = ast.literal_eval(l1) + biasing_list = [w.strip().upper() for w in biasing_list] + output[id] = " ".join(biasing_list) + + return output + + +def brian_biasing_list(level: str): + # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf + root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level" + all_files = glob.glob(root_dir + "/*") + biasing_dict = {} + for f in all_files: + k = f.split("/")[-1] + fin = open(f, "r") + data = fin.read().strip().split() + biasing_dict[k] = " ".join(data) + fin.close() + + return biasing_dict + + +def get_rare_words( + subset: str = "medium", + top_k: int = 10000, + # min_count: int = 10000, +): + """Get a list of rare words appearing less than `min_count` times + + Args: + subset: The dataset + top_k (int): How many frequent words + """ + txt_path = f"data/tmp/transcript_words_{subset}.txt" + rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt" + + if os.path.exists(rare_word_file): + print("File exists, do not proceed!") + return + + print("---Identifying rare words in the manifest---") + count_file = 
f"data/tmp/transcript_words_{subset}_count.txt" + if not os.path.exists(count_file): + with open(txt_path, "r") as file: + words = file.read().upper().split() + word_count = {} + for word in words: + word = remove_non_alphabetic(word, strict=False) + word = word.split() + for w in word: + if w not in word_count: + word_count[w] = 1 + else: + word_count[w] += 1 + + word_count = list(word_count.items()) # convert to a list of tuple + word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) + with open(count_file, "w") as fout: + for w, count in word_count: + fout.write(f"{w}\t{count}\n") + + else: + word_count = {} + with open(count_file, "r") as fin: + word_count = fin.read().strip().split("\n") + word_count = [pair.split("\t") for pair in word_count] + word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) + + print(f"A total of {len(word_count)} words appeared!") + rare_words = [] + for word, count in word_count[top_k:]: + rare_words.append(word + "\n") + print(f"A total of {len(rare_words)} are identified as rare words.") + + with open(rare_word_file, "w") as f: + f.writelines(rare_words) + + +def add_context_list_to_manifest( + manifest_dir: str, + subset: str = "medium", + top_k: int = 10000, +): + """Generate a context list of rare words for each utterance in the manifest + + Args: + manifest_dir: Where to store the manifest with context list + subset (str): Subset + top_k (int): How many frequent words + + """ + orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz" + target_manifest_dir = orig_manifest_dir.replace( + ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz" + ) + if os.path.exists(target_manifest_dir): + print(f"Target file exits at {target_manifest_dir}!") + return + + rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt" + print(f"---Reading rare words from {rare_words_file}---") + with open(rare_words_file, "r") as f: + rare_words = f.read() + rare_words = rare_words.split("\n") + rare_words = set(rare_words) + print(f"A total of {len(rare_words)} rare words!") + + cuts = load_manifest_lazy(orig_manifest_dir) + print(f"Loaded manifest from {orig_manifest_dir}") + + def _add_context(c: Cut): + splits = ( + remove_non_alphabetic(c.supervisions[0].texts[0], strict=False) + .upper() + .split() + ) + found = [] + for w in splits: + if w in rare_words: + found.append(w) + c.supervisions[0].context_list = " ".join(found) + return c + + cuts = cuts.map(_add_context) + print(f"---Saving manifest with context list to {target_manifest_dir}---") + cuts.to_file(target_manifest_dir) + print("Finished") + + +def check( + manifest_dir: str, + subset: str = "medium", + top_k: int = 10000, +): + # Show how many samples in the training set have a context list + # and the average length of context list + print("--- Calculating the stats over the manifest ---") + + manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz" + cuts = load_manifest_lazy(manifest_dir) + total_cuts = len(cuts) + has_context_list = [c.supervisions[0].context_list != "" for c in cuts] + context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts] + print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! 
") + print( + f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}" + ) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, + compute_CER: bool = False, + biasing_words: List[str] = None, +) -> float: + """Write statistics based on predicted results and reference transcripts. It also calculates the + biasing word error rate as described in https://arxiv.org/pdf/2104.02194.pdf + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + biasing_words: + All the words in the biasing list + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: # INSERTION + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: # DELETION + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: # SUBSTITUTION + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for 
i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + unbiased_word_counts = 0 + unbiased_word_errs = 0 + biased_word_counts = 0 + biased_word_errs = 0 + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + # number of appearances of "word" in reference text + ref_count = ( + corr + ref_sub + dels + ) # correct + in ref but got substituted + deleted + # number of appearances of "word" in hyp text + hyp_count = corr + hyp_sub + ins + + if biasing_words is not None: + if word in biasing_words: + biased_word_counts += ref_count + biased_word_errs += ins + dels + ref_sub + else: + unbiased_word_counts += ref_count + unbiased_word_errs += ins + dels + hyp_sub + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + + if biasing_words is not None: + B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts) + U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts) + logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ") + logging.info( + f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]" + ) + + return float(tot_err_rate) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + manifest_dir = args.manifest_dir + subset = args.subset + top_k = args.top_k + get_rare_words(subset=subset, top_k=top_k) + add_context_list_to_manifest( + manifest_dir=manifest_dir, + subset=subset, + top_k=top_k, + ) + check( + manifest_dir=manifest_dir, + subset=subset, + top_k=top_k, + ) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py new file mode 100644 index 000000000..d1cf90ffb --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py @@ -0,0 +1,2310 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. 
Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + memory_dim: if supplied and >0, will be the dimension of the memory embeddings + passed into the zipformer (e.g. this might be the output of another + Zipformer used to create embedding vectors.) + memory_dropout_rate: By this probability, do not use the provided memory for + cross-attention. This should give robustness to the model when evaluated + without memory. + memory_layer: if supplied and >0, only add cross-attention module starting from + the specified layer. + """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + memory_dim: int = -1, + memory_dropout_rate: float = 0.05, + memory_layer: int = 0, + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + self.memory_dropout_rate = memory_dropout_rate + self.memory_layer = memory_layer + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + memory_dim=memory_dim if i >= self.memory_layer else -1, + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn 
something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. 
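+        Both are -1 if the model is non-causal or if no chunking is used.  Otherwise,
+        chunk_size is randomly chosen from self.chunk_size, and left_context_chunks is
+        obtained by dividing a randomly chosen element of self.left_context_frames by
+        the chunk size (-1 means unlimited left context).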
+ """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + if self.training and memory is not None: + batch_size = x.shape[1] + # setting memory to zero should be equivalent to not using the + # memory input at all, since the Attention module has no biases. + memory = memory * ( + torch.rand(batch_size, 1, device=memory.device) + > self.memory_dropout_rate + ) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + memory=memory if i >= self.memory_layer else None, + memory_key_padding_mask=memory_key_padding_mask + if i >= self.memory_layer + else None, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. 
+ assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
+ - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
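+        causal (bool): if True, the convolution modules use chunkwise causal convolution.
+        memory_dim (int): if >0, the dimension of the memory embeddings passed in for
+            cross-attention (source attention); -1 disables cross-attention.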
+ + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + memory_dim: int = -1, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) + + self.self_attn2 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) + + if memory_dim > 0: + self.attn_weights = MultiheadAttentionWeights( + memory_dim, + embed_dim, + num_heads=num_heads, + head_dim=query_head_dim, + dropout=0.0, + ) + self.src_attn1 = Attention(memory_dim, embed_dim, num_heads, value_head_dim) + self.src_attn2 = Attention(memory_dim, embed_dim, num_heads, value_head_dim) + self.memory_balancer = Balancer( + embed_dim, + channel_dim=-1, + min_abs=0.015, + ) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2) + + self.norm = BiasNorm(embed_dim) + + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. 
+ if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, + min=float(self.bypass_min), + max=float(self.bypass_max), + ) + layer_skip_rate = float(self.layer_skip_rate) + if layer_skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + return ans + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + if memory is not None and hasattr(self, "attn_weights"): + src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask) + + src = src + self.feed_forward1(src) + + attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + + if True: + selected_attn_weights = attn_weights[0:2] + if random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. 
+ # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights[0:1])) + + src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask + ) + + if memory is not None and hasattr(self, "attn_weights"): + src = src + self.sequence_dropout( + self.memory_balancer(self.src_attn1(memory, src_attn_weights)), + attention_skip_rate, + ) + + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + float(self.conv_skip_rate), + ) + + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), float(self.ff2_skip_rate) + ) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask + ) + + if memory is not None and hasattr(self, "attn_weights"): + src = src + self.sequence_dropout( + self.memory_balancer(self.src_attn2(memory, src_attn_weights)), + attention_skip_rate, + ) + + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + float(self.conv_skip_rate), + ) + + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), float(self.ff3_skip_rate) + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
+ pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + output = output * feature_mask + + return output + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
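+
+    Concretely, forward() returns src_orig + scale * (src - src_orig), so a scale of
+    0.0 bypasses the module entirely and a scale of 1.0 keeps the module's output
+    unchanged.
+
+    Examples::
+        >>> bypass = BypassModule(embed_dim=512)
+        >>> src_orig = torch.rand(10, 32, 512)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = bypass(src_orig, src)  # same shape as src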
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0.025) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. 
+ memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. 
+ + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= x.size(0) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + T = x.size(0) + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. 
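+        # pe has shape (2*T-1, embed_dim): even columns hold cosines, odd columns hold
+        # sines, and the last column is overwritten with a constant 1.0 to act as a bias.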
+ + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> Tensor: + """Create positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + + Returns: + positional embedding, of shape (1, 2*time-1, `*`). + + """ + self.extend_pe(x) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x.size(0) + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. 
+ self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim) + chunk_size + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + if not self.training or random.random() >= float(self.pos_emb_skip_rate): + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. 
+ pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if self.training and random.random() < 0.1: + # This is away of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 25.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class Attention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. 
+ + Args: + embed_dim_in: the input embedding dimension + embed_dim_out: the output embedding dimension (normally the same as input) + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim_in: int, + embed_dim_out: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim_in, num_heads * value_head_dim, bias=False) + + # Note we set bias to False so that input of 0 will have no effect + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim_out, bias=False, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len), + Expect attn_weights.sum(dim=-1) == 1. The input here is the value in the + original attention mechanism. + Returns: + a tensor with the same shape as x. + """ + (num_heads, batch_size, query_len, key_len) = attn_weights.shape + + x = self.in_proj(x) # (key_len, batch_size, num_heads * value_head_dim) + x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, key_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, query_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(query_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (query_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] 
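
In `streaming_forward()`, the projected values for the most recent `left_context_len` frames are cached and prepended to the next chunk before the attention-weighted sum. A small standalone sketch of that pattern (not part of the patch; all sizes are assumed):

```python
import torch

# Standalone sketch of the streaming left-context pattern used above.
num_heads, batch_size, seq_len, left_context_len, head_dim = 4, 2, 4, 8, 12
seq_len2 = seq_len + left_context_len

cached_val = torch.zeros(left_context_len, batch_size, num_heads * head_dim)
chunk = torch.randn(seq_len, batch_size, num_heads * head_dim)  # projected values

x = torch.cat([cached_val, chunk], dim=0)  # (seq_len2, batch, num_heads * head_dim)
cached_val = x[-left_context_len:]         # cache for the next chunk

# attention weights: current chunk as queries, cached + current frames as keys
attn_weights = torch.softmax(
    torch.randn(num_heads, batch_size, seq_len, seq_len2), dim=-1
)
v = x.reshape(seq_len2, batch_size, num_heads, head_dim).permute(2, 1, 0, 3)
out = torch.matmul(attn_weights, v)        # (num_heads, batch, seq_len, head_dim)
out = out.permute(2, 1, 0, 3).reshape(seq_len, batch_size, num_heads * head_dim)
print(out.shape)  # torch.Size([4, 2, 48])
```
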
+ + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class MultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head cross-attention weights. Allows src and target + to have different dims. + + Args: + key_embed_dim: number of channels of the thing that we'll project to + make the query (corresponds to source). e.g. 256 + query_embed_dim: number of channels of the thing that we'll project to + make the query (corresponds to target). e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + head_dim: dimension of the query and key, per head. e.g. 24. + dropout: dropout probability for attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + key_embed_dim: int, + query_embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.key_embed_dim = key_embed_dim + self.query_embed_dim = query_embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + self.name = None # will be overwritten in training code; for diagnostics. + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.query_in_proj = ScaledLinear( + query_embed_dim, + head_dim * num_heads, + bias=True, + initial_scale=head_dim**-0.25, + ) + + # weights produced by this module are invariant to adding a constant to + # the keys, so we don't need a bias for the keys. + self.key_in_proj = ScaledLinear( + key_embed_dim, + head_dim * num_heads, + bias=False, + initial_scale=head_dim**-0.25, + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + key: Tensor, + query: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + key: input of shape (key_len, batch_size, key_embed_dim) + query: input of shape (query_len, batch_size, query_embed_dim) + key_padding_mask: an optional bool tensor of shape (batch_size, key_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, query_len, key_len) + """ + q = self.query_in_proj(query) + k = self.key_in_proj(key) + + head_dim = self.head_dim + num_heads = self.num_heads + + query_len, batch_size, _ = q.shape + key_len, _batch_size, _ = k.shape + assert _batch_size == batch_size + + k = self.whiten_keys(k) # does nothing in the forward pass. + + q = q.reshape(query_len, batch_size, num_heads, head_dim) + k = k.reshape(key_len, batch_size, num_heads, head_dim) + + # time1 refers to target, time2 refers to source. 
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + if self.training and random.random() < 0.1: + # This is a way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 25.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, query_len, key_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + key_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. 
+ x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. 
+ Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in + # the range 1 to 4, but sometimes, for some reason, for layer 0 the + # rms ends up being very large, between 50 and 100 for different channels. + # This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. 
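
Ignoring the Balancer (which only affects gradients during training), the gating described in the comment above and applied in `forward()` further below is the standard gated linear unit. A rough standalone sketch with assumed sizes:

```python
import torch

# in_proj followed by chunk + sigmoid gating is equivalent to nn.functional.glu
time, batch, channels = 10, 2, 8
x = torch.randn(time, batch, channels)
in_proj = torch.nn.Linear(channels, 2 * channels)

y = in_proj(x)
a, s = y.chunk(2, dim=-1)
gated = a * torch.sigmoid(s)
assert torch.allclose(gated, torch.nn.functional.glu(y, dim=-1))
```
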
+ self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=-1) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
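
In the causal branch, the depthwise convolution keeps a cache of `left_pad` frames between chunks, which is what the `streaming_forward()` path just below relies on. A rough sketch of that caching pattern (an approximation, not the real `ChunkCausalDepthwiseConv1d`; sizes are assumed):

```python
import torch

# Rough approximation of streaming a causal depthwise convolution chunk by chunk.
batch, channels, kernel_size, chunk = 2, 8, 5, 6
left_pad = kernel_size - 1
conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels, padding=0)

cache = torch.zeros(batch, channels, left_pad)  # cached left context
x = torch.randn(batch, channels, chunk)         # current chunk

x_padded = torch.cat([cache, x], dim=2)         # prepend cached frames
y = conv(x_padded)                              # output length equals chunk length
cache = x_padded[:, :, -left_pad:]              # update cache for the next chunk
print(y.shape, cache.shape)  # torch.Size([2, 8, 6]) torch.Size([2, 8, 4])
```
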
+ + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + memory_dim = 100 + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + memory_dim=memory_dim, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + memory=torch.randn(101, batch_size, memory_dim), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) diff --git a/icefall/utils.py b/icefall/utils.py index 8fda3a4ca..410340d9d 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -483,7 +483,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: def store_transcripts( - filename: Pathlike, texts: Iterable[Tuple[str, str, str]] + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False ) -> None: """Save predicted results and reference transcripts to a file. @@ -500,6 +500,9 @@ def store_transcripts( """ with open(filename, "w") as f: for cut_id, ref, hyp in texts: + if char_level: + ref = list("".join(ref)) + hyp = list("".join(hyp)) print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) @@ -557,6 +560,7 @@ def write_error_stats( test_set_name: str, results: List[Tuple[str, str]], enable_log: bool = True, + compute_CER: bool = False, sclite_mode: bool = False, ) -> float: """Write statistics based on predicted results and reference transcripts. @@ -585,7 +589,7 @@ def write_error_stats( The reference word `SIR` is missing in the predicted results (a deletion error). results: - An iterable of tuples. The first element is the cur_id, the second is + An iterable of tuples. The first element is the cut_id, the second is the reference transcript and the third element is the predicted result. enable_log: If True, also print detailed WER to the console. 
@@ -602,6 +606,14 @@ def write_error_stats( words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) num_corr = 0 ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) for ref_word, hyp_word in ali: @@ -1426,7 +1438,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa def get_parameter_groups_with_lrs( - model: nn.Module, lr: float, include_names: bool = False + model: nn.Module, + lr: float, + include_names: bool = False, + freeze_modules: List[str] = [], ) -> List[dict]: """ This is for use with the ScaledAdam optimizers (more recent versions that accept lists of @@ -1450,6 +1465,8 @@ def get_parameter_groups_with_lrs( ... ] """ + named_modules = list(model.named_modules()) + # flat_lr_scale just contains the lr_scale explicitly specified # for each prefix of the name, e.g. 'encoder.layers.3', these need # to be multiplied for all prefix of the name of any given parameter. @@ -1469,6 +1486,15 @@ def get_parameter_groups_with_lrs( split_name = name.split(".") # caution: as a special case, if the name is '', split_name will be [ '' ]. prefix = split_name[0] + if prefix == "module": # DDP + module_name = split_name[1] + if module_name in freeze_modules: + logging.info(f"Remove {name} from parameters") + continue + else: + if prefix in freeze_modules: + logging.info(f"Remove {name} from parameters") + continue cur_lr = lr * flat_lr_scale[prefix] if prefix != "": cur_lr *= flat_lr_scale[""] From 2b3c5d799f3a585dc22071a9148424ff77aefd47 Mon Sep 17 00:00:00 2001 From: Wen Ding Date: Wed, 11 Oct 2023 16:58:00 +0800 Subject: [PATCH 063/113] Fix padding issues (#1303) --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bcd419fb7..ab46e233b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -158,7 +158,7 @@ class Conformer(EncoderInterface): if not is_jit_tracing(): assert x.size(0) == lengths.max().item() - src_key_padding_mask = make_pad_mask(lengths) + src_key_padding_mask = make_pad_mask(lengths, x.size(0)) if self.dynamic_chunk_training: assert ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 5b75b8d35..cbde2a2e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -281,7 +281,7 @@ class Zipformer(EncoderInterface): lengths = (x_lens - 7) >> 1 assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) + mask = make_pad_mask(lengths, x.size(0)) outputs = [] feature_masks = self.get_feature_masks(x) From 855492156a3c84bca67870d808d033fe963f16bf Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 12 Oct 2023 16:48:23 +0800 Subject: [PATCH 064/113] Update finetune.py (#1304) --- egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index 82bc882bd..c943a84af 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -734,7 +734,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: From 162ceaf4b3110d452b5fed337d721c046d7787fa Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 12 Oct 2023 17:05:41 +0800 Subject: [PATCH 065/113] fixes for data preparation (#1307) Issue: #1306 --- egs/aishell/ASR/prepare.sh | 11 +++++++---- egs/librispeech/ASR/prepare.sh | 14 ++++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 9de060e73..d5dbe5726 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -204,10 +204,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py --lang-dir $lang_char_dir fi - - if [ ! -f $lang_char_dir/HLG.fst ]; then - ./local/prepare_lang_fst.py --lang-dir $lang_phone_dir --ngram-G ./data/lm/G_3_gram.fst.txt - fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then @@ -262,6 +258,13 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then --max-order=3 \ data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt fi + + if [ ! -f $lang_char_dir/HLG.fst ]; then + lang_phone_dir=data/lang_phone + ./local/prepare_lang_fst.py \ + --lang-dir $lang_phone_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 93d010ea8..739608572 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -242,10 +242,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then $lang_dir/L_disambig.pt \ $lang_dir/L_disambig.fst fi - - if [ ! -f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py --lang-dir $lang_dir --ngram-G ./data/lm/G_3_gram.fst.txt - fi done fi @@ -303,6 +299,16 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then --max-order=4 \ $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi + done fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then From eeeeef390b2d7f1aefe742ac069565d5f8eb8a38 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 12 Oct 2023 22:02:49 +0800 Subject: [PATCH 066/113] Minor bug fixes and descriptive text for the `LibriCSS` recipe (#1268) --- egs/libricss/SURT/prepare.sh | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh index 3d2581d96..b2d37f949 100755 --- a/egs/libricss/SURT/prepare.sh +++ b/egs/libricss/SURT/prepare.sh @@ -90,6 +90,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # NOTE: Alignments are required for this recipe. mkdir -p data/manifests + log "This recipe uses mfa alignment for trimming" + if [ ! -d $dl_dir/libri_alignments/LibriSpeech ]; then + log "No alignment provided. 
please refer to ../../librispeech/ASR/add_alignments.sh \n \ + for mfa alignments. Once you have downloaded and unzipped the .zip file containing \n \ + all alignments, the folder should be renamed to libri_alignments and moved to your $dl_dir ." + exit 0 + fi + lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \ -j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/ fi @@ -118,9 +126,12 @@ fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts" - python local/compute_fbank_librispeech.py - lhotse combine data/manifests/librispeech_cuts_train* - |\ - lhotse cut trim-to-alignments --type word --max-pause 0.2 - - |\ + # python local/compute_fbank_librispeech.py + lhotse combine data/manifests/librispeech_cuts_train* data/manifests/librispeech_cuts_train_all.jsonl.gz + lhotse cut trim-to-alignments --type word --max-pause 0.2 \ + data/manifests/librispeech_cuts_train_all.jsonl.gz \ + data/manifests/librispeech_cuts_train_all_trimmed.jsonl.gz + cat <(gunzip -c data/manifests/librispeech_cuts_train_all_trimmed.jsonl.gz) | \ shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz fi @@ -152,7 +163,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz # Full training set (2,3 speakers) anechoic - log "Generating anechoic ${part} set (full)" + log "Generating anechoic set (full)" lhotse workflows simulate-meetings \ --method conversational \ --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \ From 1ef349d120acef5d48feee58c4462a56f4a8c995 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 16 Oct 2023 16:28:16 +0800 Subject: [PATCH 067/113] [WIP] AISHELL-1 pruned transducer stateless7 streaming recipe (#1300) * `pruned_transudcer_stateless7_streaming` for AISHELL-1 * Update train.py * Update train2.py * Update decode.py * Update RESULTS.md --- egs/aishell/ASR/RESULTS.md | 50 + .../README.md | 1 + .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 735 ++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export-for-ncnn-zh.py | 1 + .../export-for-ncnn.py | 1 + .../export-onnx-zh.py | 1 + .../export-onnx.py | 1 + .../export.py | 1 + .../jit_pretrained.py | 1 + .../jit_trace_export.py | 1 + .../jit_trace_pretrained.py | 1 + .../joiner.py | 1 + .../model.py | 1 + .../ncnn_custom_layer.py | 1 + .../onnx_check.py | 1 + .../onnx_model_wrapper.py | 1 + .../onnx_pretrained.py | 1 + .../optim.py | 1 + .../pretrained.py | 1 + .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming-ncnn-decode.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 627 +++++++++ .../test_model.py | 1 + .../train.py | 1251 ++++++++++++++++ .../train2.py | 1253 +++++++++++++++++ .../zipformer.py | 1 + .../zipformer2.py | 1 + 34 files changed, 3945 insertions(+) create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py create 
mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 5088497a1..a2d32013a 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,6 +2,56 @@ ### Aishell training result(Stateless Transducer) +#### Pruned transducer stateless 7 streaming +[./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +It's Streaming version of Zipformer1 with Pruned RNNT loss. 
+ +| | test | dev | comment | +|------------------------|------|------|---------------------------------------| +| greedy search | 6.95 | 6.29 | --epoch 44 --avg 15 --max-duration 600 | +| modified beam search | 6.51 | 5.90 | --epoch 44 --avg 15 --max-duration 600 | +| fast beam search | 6.73 | 6.09 | --epoch 44 --avg 15 --max-duration 600 | + +Training command is: + +```bash +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --use-fp16 1 \ + --context-size 1 \ + --max-duration 800 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --enable-musan 0 \ + --spec-aug-time-warp-factor 20 +``` + +**Caution**: It uses `--context-size=1`. + +The decoding command is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 44 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + + + #### Pruned transducer stateless 7 [./pruned_transducer_stateless7](./pruned_transducer_stateless7) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 120000 index 000000000..a784292cd --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/README.md \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..f5ae836fd --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,735 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import ContextGraph +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding_method is modified_beam_search. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding_method is modified_beam_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. 
For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + token_table: + It maps token ID to a string. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + else: + hyp_tokens = [] + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_tokens.append(hyp) + + hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + key = f"beam_size_{params.beam_size}" + if params.has_contexts: + key += f"-context-score-{params.context_score}" + else: + key += "-no-context-words" + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + token_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. 
+ params: + It is returned by :func:`get_params`. + model: + The neural model. + token_table: + It maps a token ID to a string. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + token_table=token_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += "-no-contexts-words" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", 
model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts_text = [] + for line in open(params.context_file).readlines(): + contexts_text.append(line.strip()) + contexts = graph_compiler.texts_to_ids(contexts_text) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
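+    # With return_cuts=True the dataloader keeps the original Cut objects in
+    # batch["supervisions"]["cut"], from which decode_dataset() reads the
+    # cut ids.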
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + import time + + for test_set, test_dl in zip(test_sets, test_dls): + start = time.time() + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + ) + logging.info(f"Elasped time for {test_set}: {time.time() - start}") + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 120000 index 000000000..72e43c297 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 120000 index 000000000..3b36924ef --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py new file mode 120000 index 000000000..eca5e2956 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 000000000..57a0cd0a0 --- /dev/null +++ 
b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 000000000..2acafdc61 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 120000 index 000000000..5d9c6ba00 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 120000 index 000000000..457131699 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 120000 index 000000000..2b8fa3cbb --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py new file mode 120000 index 000000000..8eea90e04 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 000000000..28bf7bb82 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git 
a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 000000000..c8548d459 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 000000000..ae4d9bb04 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 000000000..9510b8fde --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..1199a61d6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..6b4f183cf --- /dev/null +++ 
b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +import os +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall import ContextGraph +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
+    tail_length = 23
+    if features.size(1) < tail_length:
+        pad_length = tail_length - features.size(1)
+        feature_lens += pad_length
+        features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, pad_length),
+            mode="constant",
+            value=LOG_EPS,
+        )
+
+    states = stack_states(states)
+    processed_lens = torch.tensor(processed_lens, device=device)
+
+    encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
+        x=features,
+        x_lens=feature_lens,
+        states=states,
+    )
+
+    encoder_out = model.joiner.encoder_proj(encoder_out)
+
+    if params.decoding_method == "greedy_search":
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+    elif params.decoding_method == "fast_beam_search":
+        processed_lens = processed_lens + encoder_out_lens
+        fast_beam_search_one_best(
+            model=model,
+            encoder_out=encoder_out,
+            processed_lens=processed_lens,
+            streams=decode_streams,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        modified_beam_search(
+            model=model,
+            streams=decode_streams,
+            encoder_out=encoder_out,
+            num_active_paths=params.num_active_paths,
+        )
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+    states = unstack_states(new_states)
+
+    finished_streams = []
+    for i in range(len(decode_streams)):
+        decode_streams[i].states = states[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
+        if decode_streams[i].done:
+            finished_streams.append(i)
+
+    return finished_streams
+
+
+def decode_dataset(
+    cuts: CutSet,
+    params: AttributeDict,
+    model: nn.Module,
+    token_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      cuts:
+        Lhotse CutSet containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      token_table:
+        It maps a token ID to a string.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+    """
+    device = model.device
+
+    opts = FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+
+    log_interval = 50
+
+    decode_results = []
+    # Contains the decode streams that are currently running.
+    decode_streams = []
+    for num, cut in enumerate(cuts):
+        # Each utterance has its own DecodeStream.
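+        # A DecodeStream keeps the per-utterance encoder states, the not yet
+        # consumed feature frames and the partial decoding result, so that
+        # many utterances can be advanced chunk by chunk in decode_one_chunk().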
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + token_table[result] + for result in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + token_table[result] + for result in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + if params.decoding_method == "modified_beam_search": + if os.path.exists(params.context_file): + contexts_text = [] + for line in open(params.context_file).readlines(): + contexts_text.append(line.strip()) + contexts = graph_compiler.texts_to_ids(contexts_text) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + valid_cuts = aishell.valid_cuts() + + test_sets = ["test", "valid"] + cuts = [test_cuts, valid_cuts] + + for test_set, test_cut in zip(test_sets, cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 120000 index 000000000..1259849e0 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..2e1044658 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma 
separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. 
+ """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> 
Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
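+            # Instead we override the scale manually below: it is doubled when
+            # it falls below 1.0 (or below 8.0, every 400 batches), and training
+            # is aborted if it keeps collapsing, which typically means the loss
+            # is producing inf/nan gradients.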
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + oov="", + ) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + # T = ((c.num_frames - 7) // 2 + 1) // 2 + # tokens = sp.encode(c.supervisions[0].text, out_type=str) + + # if T < len(tokens): + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
" + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"Text: {c.supervisions[0].text}. " + # f"Tokens: {tokens}. " + # f"Number of tokens: {len(tokens)}" + # ) + return False + + return True + + # train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_dl = aishell.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # graph_compiler=graph_compiler, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..88eb34104 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1253 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
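# NOTE: this script builds its encoder from zipformer2.py with `is_pnnx=True`
# (see get_encoder_model below).  That Zipformer variant is the one that can be
# exported through PNNX for ncnn deployment; apart from the encoder import, the
# training logic appears to mirror the regular train.py in this directory.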
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. 
+ scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
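            # Because the scaler is created with init_scale=1.0 in run() below
            # (rather than GradScaler's default of 65536), the scale usually has
            # to grow during training; the manual doubling here compensates for
            # the fact that the scaler's automatic growth interval cannot be made
            # to depend on the current scale value.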
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + oov="", + ) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + # T = ((c.num_frames - 7) // 2 + 1) // 2 + # tokens = sp.encode(c.supervisions[0].text, out_type=str) + + # if T < len(tokens): + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
" + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"Text: {c.supervisions[0].text}. " + # f"Tokens: {tokens}. " + # f"Number of tokens: {len(tokens)}" + # ) + # return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From d2bd0933b1462fefdc7ac2b41881ae0eb71be873 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 17 Oct 2023 21:22:32 +0800 Subject: [PATCH 068/113] Compatibility with the latest Lhotse (#1314) --- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- .../ASR/transducer_stateless_modified-2/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- .../ASR_v2/pruned_transducer_stateless7/asr_datamodule.py | 3 +-- egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 +- egs/ami/SURT/dprnn_zipformer/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 +- egs/csj/ASR/local/utils/asr_datamodule.py | 2 +- egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py | 2 +- egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless3/asr_datamodule.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py | 2 +- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 3 +-- egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 +-- 
egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py | 2 +- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 +- 26 files changed, 26 insertions(+), 29 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 49a697bfd..3667c2ad0 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -211,7 +211,7 @@ class Aidatatang_200zhAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 9c6021a19..cd8dd821c 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -160,7 +160,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py index af37cc175..8f6a88f59 100644 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -218,7 +218,7 @@ class AiShell2AsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index da9da371e..4ad98fb51 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -228,7 +228,7 @@ class Aishell4AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index 4799da19d..5ad80817a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -211,7 +211,7 @@ class AlimeetingAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: 
logging.info("Disable MUSAN") diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py index 1cfd053c7..9d288218a 100644 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -208,7 +208,7 @@ class AlimeetingAsrDataModule: logging.info("Enable MUSAN") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -288,7 +288,6 @@ class AlimeetingAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] if self.args.concatenate_cuts: transforms = [ diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py index f7ee9c962..79474f1d8 100644 --- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -214,7 +214,7 @@ class AmiAsrDataModule: logging.info("Enable MUSAN") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py index 3dd786d33..1549c1631 100644 --- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py @@ -202,7 +202,7 @@ class AmiAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index 73f2f1dce..546e9f9dd 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -230,7 +230,7 @@ class CommonVoiceAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py index 272486227..042b6ecbf 100644 --- a/egs/csj/ASR/local/utils/asr_datamodule.py +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -256,7 +256,7 @@ class CSJAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index 9d6e3c42a..a93e224d5 100644 --- 
a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -194,7 +194,7 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 29e72b408..b5b27ce95 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -217,7 +217,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py index a72df89e0..c1abdbdb5 100644 --- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py @@ -204,7 +204,7 @@ class LibriCssAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index f8f558ce1..ee7556e49 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -209,7 +209,7 @@ class LibriSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index b7735be85..057624272 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -164,7 +164,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py index 75e153cb0..cd432fd6f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -217,7 +217,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + 
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 20df469da..c500eb3e5 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -233,7 +233,7 @@ class LibriSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py index 442ff85c2..7753d1674 100644 --- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py @@ -182,7 +182,6 @@ class MGB2AsrDataModule: cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, ) -> DataLoader: - transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") @@ -190,7 +189,7 @@ class MGB2AsrDataModule: cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 3d58ebf3a..02cfa1346 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -219,7 +219,7 @@ class AsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index d94a92503..cf70fc0f8 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -182,7 +182,7 @@ class SPGISpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -261,7 +261,6 @@ class SPGISpeechAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] if self.args.concatenate_cuts: transforms = [ diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index aeeb2ef78..ce8634a1d 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -220,7 +220,7 @@ class SwitchBoardAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git 
a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 39beffdcf..5269a1778 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -228,7 +228,7 @@ class TAL_CSASRAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index 28d0d3826..d4a9e4bc9 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -205,7 +205,7 @@ class TedLiumAsrDataModule: logging.info("Enable MUSAN") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index 7c299d601..5d1b3c367 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -157,7 +157,7 @@ class TimitAsrDataModule(DataModule): cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz") logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + transforms = [CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20))] if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c5967f10a..1dbfb9709 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -215,7 +215,7 @@ class WenetSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py index 6362ab7cd..7594fb28e 100644 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -218,7 +218,7 @@ class Xbmu_AmdoAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") From 807816fec0dde1bfa0f0a2f20d36552cc3d84a90 Mon Sep 17 00:00:00 2001 From: Erwan Zerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:07:10 +0200 Subject: [PATCH 069/113] Fix chunk issue for sherpa (#1316) --- egs/librispeech/ASR/zipformer/zipformer.py | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff 
--git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 1a174b315..61ae378d8 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -17,28 +17,33 @@ # limitations under the License. import copy +import logging import math +import random import warnings from typing import List, Optional, Tuple, Union -import logging + import torch -import random from encoder_interface import EncoderInterface from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, Balancer, BiasNorm, - Dropout2, ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Dropout2, + FloatLike, + ScheduledFloat, Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + convert_num_channels, + limit_param_value, penalize_abs_values_gt, softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, ) from torch import Tensor, nn @@ -2098,7 +2103,7 @@ class NonlinAttention(nn.Module): (seq_len, batch_size, _) = x.shape hidden_channels = self.hidden_channels - s, x, y = x.chunk(3, dim=-1) + s, x, y = x.chunk(3, dim=2) # s will go through tanh. @@ -2151,7 +2156,7 @@ class NonlinAttention(nn.Module): (seq_len, batch_size, _) = x.shape hidden_channels = self.hidden_channels - s, x, y = x.chunk(3, dim=-1) + s, x, y = x.chunk(3, dim=2) # s will go through tanh. s = self.tanh(s) @@ -2308,7 +2313,7 @@ class ConvolutionModule(nn.Module): x = self.in_proj(x) # (time, batch, 2*channels) - x, s = x.chunk(2, dim=-1) + x, s = x.chunk(2, dim=2) s = self.balancer1(s) s = self.sigmoid(s) x = self.activation1(x) # identity. 
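Note on the `dim=-1` → `dim=2` change above: the tensors being split in `NonlinAttention` and `ConvolutionModule` are 3-D `(seq_len, batch, channels)`, so both calls split the same axis and the numerical behaviour is unchanged; addressing the axis explicitly is presumably what the exported streaming (sherpa) graph needs, as the commit title suggests. A minimal sketch of the equivalence, with illustrative shapes only:

```python
import torch

# (seq_len, batch, 3 * hidden) -- shapes are illustrative, not taken from the model
x = torch.randn(10, 2, 6)

s1, a1, b1 = x.chunk(3, dim=-1)  # old code: split along the last axis
s2, a2, b2 = x.chunk(3, dim=2)   # new code: same axis, addressed explicitly

assert torch.equal(s1, s2) and torch.equal(a1, a2) and torch.equal(b1, b2)
```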
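Similarly, the `prob` → `p` rename applied to every `CutMix(...)` call in the data modules of the earlier commit in this series tracks a keyword-argument rename in newer lhotse versions; only the argument name changes, not the mixing behaviour. A hedged sketch of the updated call (the manifest path is illustrative):

```python
from lhotse import load_manifest
from lhotse.dataset import CutMix

# Illustrative path; the recipes load it from args.manifest_dir.
cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

# Newer lhotse versions take `p` where older ones took `prob`.
transform = CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
```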
From 52c24df61da3d04a6fdcab32d5615c394951279b Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 18 Oct 2023 17:36:14 +0800 Subject: [PATCH 070/113] Fix model avg (#1317) * fix a bug about the model_avg during finetuning by exchanging the order of loading pre-trained model and initializing avg model * only match the exact module prefix --- .../ASR/pruned_transducer_stateless7/finetune.py | 11 +++++++++-- .../ASR/pruned_transducer_stateless2/finetune.py | 8 ++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 4e261dbc1..a7a8ef149 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -655,8 +655,12 @@ def load_model_params( dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key) @@ -1089,6 +1093,9 @@ def run(rank, world_size, args): checkpoints = load_model_params( ckpt=params.finetune_ckpt, model=model, init_modules=modules ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) else: assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index c943a84af..ba91980d3 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -498,8 +498,12 @@ def load_model_params( dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key) From 98c5286404a0add86bc6243171fc092ea89c51bb Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Thu, 19 Oct 2023 01:13:50 +0900 Subject: [PATCH 071/113] Fix typo in code-style.rst (#1318) --- docs/source/contributing/code-style.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst index 3baaaeec2..cb08229c3 100644 --- a/docs/source/contributing/code-style.rst +++ b/docs/source/contributing/code-style.rst @@ -38,7 +38,7 @@ Please fix any issues reported by the check tools. .. HINT:: Some of the check tools, i.e., ``black`` and ``isort`` will modify - the files to be commited **in-place**. 
So please run ``git status`` + the files to be committed **in-place**. So please run ``git status`` after failure to see which file has been modified by the tools before you make any further changes. From 36c60b0cf6b172e7739f5b177e731faa03737967 Mon Sep 17 00:00:00 2001 From: Surav Shrestha <98219089+suravshresth@users.noreply.github.com> Date: Thu, 19 Oct 2023 09:00:18 +0545 Subject: [PATCH 072/113] fix typos in icefall/utils.py (#1319) --- icefall/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index 410340d9d..6479d8f87 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1447,7 +1447,7 @@ def get_parameter_groups_with_lrs( This is for use with the ScaledAdam optimizers (more recent versions that accept lists of named-parameters; we can, if needed, create a version without the names). - It provides a way to specifiy learning-rate scales inside the module, so that if + It provides a way to specify learning-rate scales inside the module, so that if any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will scale the LR of any parameters inside that module or its submodules. Note: you can set module parameters outside the __init__ function, e.g.: @@ -1607,10 +1607,10 @@ def tokenize_by_bpe_model( chars = pattern.split(txt.upper()) mix_chars = [w for w in chars if len(w.strip()) > 0] for ch_or_w in mix_chars: - # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + # ch_or_w is a single CJK character(i.e., "你"), do nothing. if pattern.fullmatch(ch_or_w) is not None: tokens.append(ch_or_w) - # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # ch_or_w contains non-CJK characters(i.e., " IT'S OKAY "), # encode ch_or_w using bpe_model. else: for p in sp.encode_as_pieces(ch_or_w): @@ -1624,7 +1624,7 @@ def tokenize_by_CJK_char(line: str) -> str: """ Tokenize a line of text with CJK char. - Note: All return charaters will be upper case. + Note: All return characters will be upper case. Example: input = "你好世界是 hello world 的中文" @@ -1917,7 +1917,7 @@ def parse_bpe_timestamps_and_texts( A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't - be meaningful). Its attribtutes `labels` and `aux_labels` + be meaningful). Its attributes `labels` and `aux_labels` are both BPE tokens. sp: The BPE model. @@ -2045,7 +2045,7 @@ def parse_fsa_timestamps_and_texts( ) -> Tuple[List[Tuple[float, float]], List[List[str]]]: """Parse timestamps (in seconds) and texts for given decoded fsa paths. Currently it supports two cases: - (1) ctc-decoding, the attribtutes `labels` and `aux_labels` + (1) ctc-decoding, the attributes `labels` and `aux_labels` are both BPE tokens. In this case, sp should be provided. (2) HLG-based 1best, the attribtute `labels` is the prediction unit, e.g., phone or BPE tokens; attribute `aux_labels` is the word index. 
From ce372cce33ad7594baf603f75264950d88fa329c Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:24:31 +0800 Subject: [PATCH 073/113] Update documentation to PromptASR (#1321) --- .../zipformer_prompt_asr/train_baseline.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 7075c9154..32302602c 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,21 +22,35 @@ Usage: -# For mix precision training: +# For mix precision training, using MCP style transcript: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./zipformer/train.py \ +./zipformer_prompt_asr/train_baseline.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir zipformer/exp \ + --exp-dir zipformer_prompt_asr/exp \ + --transcript-style MCP \ + --max-duration 1000 + +# For mix precision training, using UC style transcript: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer_prompt_asr/train_baseline.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer_prompt_asr/exp \ + --transcript-style UC \ --max-duration 1000 # To train a streaming model -./zipformer/train.py \ +./zipformer_prompt_asr/train_baseline.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 1 \ @@ -100,7 +115,7 @@ from icefall.utils import ( LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -def get_first( +def get_mixed_cased_with_punc( texts: List[str], pre_texts: List[str], context_list: Optional[str] = None, @@ -479,6 +494,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--transcript-style", + type=str, + default="UC", + choices=["UC", "MCP"], + help="""The transcript style used for training. UC stands for upper-cased text w/o punctuations, + MCP stands for mix-cased text with punctuation. 
+ """, + ) + add_model_arguments(parser) return parser @@ -1223,7 +1248,11 @@ def run(rank, world_size, args): else: sampler_state_dict = None - text_sampling_func = get_upper_only_alpha + if params.transcript_style == "UC": + text_sampling_func = get_upper_only_alpha + else: + text_sampling_func = get_mixed_cased_with_punc + logging.info(f"Using {params.transcript_style} style for training.") logging.info(f"Text sampling func: {text_sampling_func}") train_dl = libriheavy.train_dataloaders( train_cuts, From 543b4cc1ca45f5a6e273cb1440a233e5fc51fa36 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Thu, 19 Oct 2023 15:53:31 +0200 Subject: [PATCH 074/113] small enhanecements (#1322) - add extra check of 'x' and 'x_lens' to earlier point in Transducer model - specify 'utf' encoding when opening text files for writing (recogs, errs) --- egs/librispeech/ASR/pruned_transducer_stateless7/model.py | 3 +++ icefall/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..add0e6a18 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -114,6 +114,9 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 + # x.T_dim == max(x_len) + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max()) + encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) diff --git a/icefall/utils.py b/icefall/utils.py index 6479d8f87..399e8d8b3 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -498,7 +498,7 @@ def store_transcripts( Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp in texts: if char_level: ref = list("".join(ref)) @@ -523,7 +523,7 @@ def store_transcripts_and_timestamps( Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) From 973dc1026d93c5ce551428459077187a3cd1e0a9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 19 Oct 2023 22:54:00 +0800 Subject: [PATCH 075/113] Make diagnostics.py more error-tolerant and have wider range of supported torch versions (#1234) --- icefall/diagnostics.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 700dc1500..ebf61784e 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -244,12 +244,22 @@ class TensorDiagnostic(object): if stats_type == "eigs": try: - eigs, _ = torch.symeig(stats) + if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'): + eigs, _ = torch.linalg.eigh(stats) + else: + eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print("Error getting eigenvalues, trying another method.") - eigs, _ = torch.eig(stats) - stats = eigs.norm(dim=1).sqrt() + print( + "Error getting eigenvalues, trying another method." 
+ ) + if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'): + eigs, _ = torch.linalg.eig(stats) + eigs = eigs.abs() + else: + eigs, _ = torch.eig(stats) + eigs = eigs.norm(dim=1) + stats = eigs.sqrt() # sqrt so it reflects data magnitude, like stddev- not variance if stats_type in ["rms", "stddev"]: @@ -569,11 +579,10 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in (torch.float32, torch.float16, torch.float64): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate( - o, class_name=get_class_name(_module) - ) - + if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, + class_name=get_class_name(_module)) + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] @@ -587,11 +596,9 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in (torch.float32, torch.float16, torch.float64): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( - o, class_name=get_class_name(_module) - ) - + if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, + class_name=get_class_name(_module)) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) From eef47adee9aa765f41cd63a8d57049b02849f3ad Mon Sep 17 00:00:00 2001 From: Rudra <92840555+Rudra-Ji@users.noreply.github.com> Date: Thu, 19 Oct 2023 20:24:43 +0530 Subject: [PATCH 076/113] fix typo (#1324) --- docs/source/decoding-with-langugage-models/LODR.rst | 2 +- docs/source/model-export/export-ncnn-conv-emformer.rst | 2 +- egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py | 2 +- egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py | 2 +- icefall/utils.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst index 8cc1a624c..b6b6e8cbb 100644 --- a/docs/source/decoding-with-langugage-models/LODR.rst +++ b/docs/source/decoding-with-langugage-models/LODR.rst @@ -56,7 +56,7 @@ during decoding for transducer model: \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right) -In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR, +In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR, the only difference lies in the choice of source domain LM. According to the original `paper `_, LODR achieves similar performance compared DR in both intra-domain and cross-domain settings. As a bi-gram is much faster to evaluate, LODR is usually much faster. diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 4f5535d83..93392aee7 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -125,7 +125,7 @@ Python code. We have also set up ``PATH`` so that you can use .. caution:: Please don't use ``_. - We have made some modifications to the offical `ncnn`_. + We have made some modifications to the official `ncnn`_. 
We will synchronize ``_ periodically with the official one. diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index bdd1f27bc..2bafe25d6 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -203,7 +203,7 @@ def get_parser(): "--beam-size", type=int, default=4, - help="""An interger indicating how many candidates we will keep for each + help="""An integer indicating how many candidates we will keep for each frame. Used only when --decoding-method is beam_search or modified_beam_search.""", ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index ba91980d3..c34f1593d 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -78,7 +78,7 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): default=None, help=""" Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma seperated. If None, + a specific key. The keys are given with Comma separated. If None, all modules will be initialised. For example, if you only want to initialise all parameters staring with "encoder", use "encoder"; if you want to initialise parameters starting with encoder or decoder, diff --git a/icefall/utils.py b/icefall/utils.py index 399e8d8b3..a9e8a81b9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1977,7 +1977,7 @@ def parse_timestamps_and_texts( A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't - be meaningful). Attribtute `labels` is the prediction unit, + be meaningful). Attribute `labels` is the prediction unit, e.g., phone or BPE tokens. Attribute `aux_labels` is the word index. word_table: The word symbol table. 
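The diagnostics change a few commits above prefers `torch.linalg.eigh` when it exists and falls back to the older `torch.symeig` otherwise, widening the range of supported torch versions. A standalone sketch of that compatibility pattern, assuming `stats` is a symmetric matrix (the error handling of the original is omitted):

```python
import torch

def sym_eigvals(stats: torch.Tensor) -> torch.Tensor:
    """Eigenvalues of a symmetric matrix across a range of torch versions."""
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        eigs, _ = torch.linalg.eigh(stats)  # newer API
    else:
        eigs, _ = torch.symeig(stats)  # older torch versions
    return eigs

print(sym_eigvals(torch.eye(3)))  # tensor([1., 1., 1.])
```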
From 416852e8a16f7f7f3104e95271c6d109088a416d Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sat, 21 Oct 2023 02:36:59 -0500 Subject: [PATCH 077/113] Add Zipformer recipe for GigaSpeech (#1254) Co-authored-by: Yifan Yang Co-authored-by: yfy62 --- .../run-gigaspeech-zipformer-2023-10-17.sh | 94 ++ .../run-gigaspeech-zipformer-2023-10-17.yml | 126 ++ README.md | 16 +- egs/gigaspeech/ASR/README.md | 1 + egs/gigaspeech/ASR/RESULTS.md | 74 + .../ASR/zipformer/asr_datamodule.py | 444 ++++++ egs/gigaspeech/ASR/zipformer/beam_search.py | 1 + egs/gigaspeech/ASR/zipformer/ctc_decode.py | 847 +++++++++++ egs/gigaspeech/ASR/zipformer/decode.py | 1065 +++++++++++++ egs/gigaspeech/ASR/zipformer/decode_stream.py | 1 + egs/gigaspeech/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/gigaspeech/ASR/zipformer/export-onnx.py | 620 ++++++++ egs/gigaspeech/ASR/zipformer/export.py | 522 +++++++ .../ASR/zipformer/gigaspeech_scoring.py | 1 + .../ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/gigaspeech/ASR/zipformer/joiner.py | 1 + egs/gigaspeech/ASR/zipformer/model.py | 1 + egs/gigaspeech/ASR/zipformer/onnx_check.py | 1 + egs/gigaspeech/ASR/zipformer/onnx_decode.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + egs/gigaspeech/ASR/zipformer/optim.py | 1 + egs/gigaspeech/ASR/zipformer/pretrained.py | 1 + .../ASR/zipformer/pretrained_ctc.py | 1 + egs/gigaspeech/ASR/zipformer/profile.py | 1 + egs/gigaspeech/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 853 +++++++++++ egs/gigaspeech/ASR/zipformer/subsampling.py | 1 + egs/gigaspeech/ASR/zipformer/test_scaling.py | 1 + .../ASR/zipformer/test_subsampling.py | 1 + egs/gigaspeech/ASR/zipformer/train.py | 1345 +++++++++++++++++ egs/gigaspeech/ASR/zipformer/zipformer.py | 1 + 43 files changed, 6036 insertions(+), 2 deletions(-) create mode 100755 .github/scripts/run-gigaspeech-zipformer-2023-10-17.sh create mode 100644 .github/workflows/run-gigaspeech-zipformer-2023-10-17.yml create mode 100644 egs/gigaspeech/ASR/zipformer/asr_datamodule.py create mode 120000 egs/gigaspeech/ASR/zipformer/beam_search.py create mode 100755 egs/gigaspeech/ASR/zipformer/ctc_decode.py create mode 100755 egs/gigaspeech/ASR/zipformer/decode.py create mode 120000 egs/gigaspeech/ASR/zipformer/decode_stream.py create mode 120000 egs/gigaspeech/ASR/zipformer/decoder.py create mode 120000 egs/gigaspeech/ASR/zipformer/encoder_interface.py create mode 120000 egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py create mode 100755 egs/gigaspeech/ASR/zipformer/export-onnx.py create mode 100755 egs/gigaspeech/ASR/zipformer/export.py create mode 120000 egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py create mode 120000 egs/gigaspeech/ASR/zipformer/jit_pretrained.py create mode 120000 egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py create mode 
120000 egs/gigaspeech/ASR/zipformer/joiner.py create mode 120000 egs/gigaspeech/ASR/zipformer/model.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_check.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_decode.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py create mode 120000 egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py create mode 120000 egs/gigaspeech/ASR/zipformer/optim.py create mode 120000 egs/gigaspeech/ASR/zipformer/pretrained.py create mode 120000 egs/gigaspeech/ASR/zipformer/pretrained_ctc.py create mode 120000 egs/gigaspeech/ASR/zipformer/profile.py create mode 120000 egs/gigaspeech/ASR/zipformer/scaling.py create mode 120000 egs/gigaspeech/ASR/zipformer/scaling_converter.py create mode 120000 egs/gigaspeech/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/gigaspeech/ASR/zipformer/streaming_decode.py create mode 120000 egs/gigaspeech/ASR/zipformer/subsampling.py create mode 120000 egs/gigaspeech/ASR/zipformer/test_scaling.py create mode 120000 egs/gigaspeech/ASR/zipformer/test_subsampling.py create mode 100755 egs/gigaspeech/ASR/zipformer/train.py create mode 120000 egs/gigaspeech/ASR/zipformer/zipformer.py diff --git a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh new file mode 100755 index 000000000..6bb0b9ebc --- /dev/null +++ b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/gigaspeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "exp/jit_script.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./zipformer/jit_pretrained.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --nn-model-filename $repo/exp/jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for method in greedy_search modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: 
${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p zipformer/exp + ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh zipformer/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./zipformer/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir zipformer/exp + done + + rm zipformer/exp/*.pt +fi diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml new file mode 100644 index 000000000..7572f4b5f --- /dev/null +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -0,0 +1,126 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-gigaspeech-zipformer-2023-10-17 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_gigaspeech_2023_10_17_zipformer-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_gigaspeech_2023_10_17_zipformer: + if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/gigaspeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/gigaspeech/ASR/data/fbank + ls -lh egs/gigaspeech/ASR/data/* + + sudo apt-get 
-qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-gigaspeech-zipformer-2023-10-17.sh + + - name: Display decoding results for gigaspeech zipformer + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/gigaspeech/ASR/ + tree ./zipformer/exp + + cd zipformer + echo "results for zipformer" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for gigaspeech zipformer + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 + path: egs/gigaspeech/ASR/zipformer/exp/ diff --git a/README.md b/README.md index da446109d..a14abd023 100644 --- a/README.md +++ b/README.md @@ -148,8 +148,11 @@ in the decoding. ### GigaSpeech -We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc] -and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. +We provide three models for this recipe: + +- [Conformer CTC model][GigaSpeech_conformer_ctc] +- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. 
+- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer] #### Conformer CTC @@ -165,6 +168,14 @@ and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned R | fast beam search | 10.50 | 10.69 | | modified beam search | 10.40 | 10.51 | +#### Transducer: Zipformer encoder + Embedding decoder + +| | Dev | Test | +|----------------------|-------|-------| +| greedy search | 10.31 | 10.50 | +| fast beam search | 10.26 | 10.48 | +| modified beam search | 10.25 | 10.38 | + ### Aishell @@ -378,6 +389,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless [GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc [GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2 +[GigaSpeech_zipformer]: egs/gigaspeech/ASR/zipformer [Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 [WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5 diff --git a/egs/gigaspeech/ASR/README.md b/egs/gigaspeech/ASR/README.md index 32a0457c6..f0d60898c 100644 --- a/egs/gigaspeech/ASR/README.md +++ b/egs/gigaspeech/ASR/README.md @@ -15,6 +15,7 @@ ln -sfv /path/to/GigaSpeech download/GigaSpeech ## Performance Record | | Dev | Test | |--------------------------------|-------|-------| +| `zipformer` | 10.25 | 10.38 | | `conformer_ctc` | 10.47 | 10.58 | | `pruned_transducer_stateless2` | 10.40 | 10.51 | diff --git a/egs/gigaspeech/ASR/RESULTS.md b/egs/gigaspeech/ASR/RESULTS.md index 7ab565844..841ebdcfa 100644 --- a/egs/gigaspeech/ASR/RESULTS.md +++ b/egs/gigaspeech/ASR/RESULTS.md @@ -1,4 +1,78 @@ ## Results +### zipformer (zipformer + pruned stateless transducer) + +See for more details. + +[zipformer](./zipformer) + +- Non-streaming +- normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +The tensorboard log for training is available at + + +You can use to deploy it. 
+ +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 10.31 | 10.50 | --epoch 30 --avg 9 | +| modified_beam_search | 10.25 | 10.38 | --epoch 30 --avg 9 | +| fast_beam_search | 10.26 | 10.48 | --epoch 30 --avg 9 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 0 \ + --subset XL \ + --max-duration 700 \ + --use-transducer 1 \ + --use-ctc 0 \ + --lr-epochs 1 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES=0 + +# greedy search +./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method greedy_search + +# modified beam search +./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast beam search (one best) +./zipformer/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./zipformer/exp \ + --max-duration 1000 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +``` + ### GigaSpeech BPE training results (Pruned Transducer 2) #### 2022-05-12 diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..c4472ed23 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,444 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import inspect +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import lhotse +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. 
" + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train {self.args.subset} cuts") + if self.args.subset == "XL": + filenames = glob.glob( + f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" + ) + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + sorted_filenames = [f[1] for f in idx_filenames] + logging.info( + f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" + ) + cuts_train = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + else: + path = ( + self.args.manifest_dir / 
f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" + ) + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + ) + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) diff --git a/egs/gigaspeech/ASR/zipformer/beam_search.py b/egs/gigaspeech/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..aa51036d5 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -0,0 +1,847 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(3) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(4) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(5) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from train import add_model_arguments, get_params, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. 
`len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. 
+ # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
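+    # Note: each key of `results` below is one decoding setting (e.g.
+    # "ctc-decoding", "no_rescore" or "lm_scale_0.7") and its value is a list
+    # of (cut_id, ref_words, hyp_words) tuples, e.g.
+    #   ("cut-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"])  # hypothetical values
+    # save_results() later computes one WER per key from these lists.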
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
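+        # With, e.g., --chunk-size 16 --left-context-frames 64 (hypothetical
+        # values), the suffix built below ends in "-chunk-16-left-context-64";
+        # the suffix is used in the log file and result file names.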
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
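+    # Setting return_cuts=True makes the test dataset put the Cut objects in
+    # batch["supervisions"]["cut"], from which decode_dataset() reads the ids.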
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + test_clean_cuts = gigaspeech.test_clean_cuts() + test_other_cuts = gigaspeech.test_other_cuts() + + test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts) + test_other_dl = gigaspeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/decode.py b/egs/gigaspeech/ASR/zipformer/decode.py new file mode 100755 index 000000000..3a0c71484 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/decode.py @@ -0,0 +1,1065 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as 
spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. 
+ """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
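+        # The tail padding below appends `pad_len` frames of LOG_EPS-valued
+        # features so that a causal (streaming) model has enough right context
+        # to emit the last symbols of the utterance.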
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} 
(excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/decode_stream.py b/egs/gigaspeech/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/decoder.py b/egs/gigaspeech/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/encoder_interface.py b/egs/gigaspeech/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx.py b/egs/gigaspeech/ASR/zipformer/export-onnx.py new file mode 100755 index 000000000..0f78cfe5b --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx.py @@ -0,0 +1,620 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/gigaspeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
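+
+    Example (illustrative, file name is hypothetical)::
+
+        add_meta_data("encoder-epoch-99-avg-1.onnx", {"model_type": "zipformer2"})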
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range 
from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/gigaspeech/ASR/zipformer/export.py b/egs/gigaspeech/ASR/zipformer/export.py new file mode 100755 index 000000000..e45c96b57 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export.py @@ -0,0 +1,522 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple 
authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for gigaspeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
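+
+A minimal loading sketch (assumes `params` has already been populated with the
+same model arguments, vocab size and blank id that were used for training, as
+is done in ./zipformer/decode.py; the path below is only an example):
+
+    from icefall.checkpoint import load_checkpoint
+    from train import get_model
+
+    model = get_model(params)
+    load_checkpoint("zipformer/exp/pretrained.pt", model)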
+ +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/gigaspeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/gigaspeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. 
+ """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
+ """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py new file mode 120000 index 000000000..a6a4d12b1 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/gigaspeech_scoring.py @@ -0,0 +1 @@ +../conformer_ctc/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/joiner.py b/egs/gigaspeech/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/model.py b/egs/gigaspeech/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_check.py b/egs/gigaspeech/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_decode.py b/egs/gigaspeech/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_decode.py 
@@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 120000 index 000000000..a3183ebf6 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 120000 index 000000000..a4fd76ac2 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 120000 index 000000000..f805e3761 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 120000 index 000000000..8343d5079 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/optim.py b/egs/gigaspeech/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained.py b/egs/gigaspeech/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py new file mode 120000 index 000000000..c2f6f6fc3 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/profile.py b/egs/gigaspeech/ASR/zipformer/profile.py new file mode 120000 index 000000000..c93adbd14 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/scaling.py b/egs/gigaspeech/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ 
b/egs/gigaspeech/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/scaling_converter.py b/egs/gigaspeech/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py b/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..a76788859 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import GigaSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. 
+ + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + test_sets = ["dev", "test"] + test_cuts = [dev_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/subsampling.py b/egs/gigaspeech/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/test_scaling.py b/egs/gigaspeech/ASR/zipformer/test_scaling.py new file mode 120000 index 000000000..715798436 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/test_scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/test_subsampling.py b/egs/gigaspeech/ASR/zipformer/test_subsampling.py new file mode 120000 index 000000000..bf0ee3d11 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py new file mode 100755 index 000000000..d8ff4fecc --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -0,0 +1,1345 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
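+    # Worked example (using the values from the usage string above): with
+    # --world-size 8, --max-duration 1000 and the default --ref-duration 600,
+    # every real batch counts as 1000 * 8 / 600, i.e. about 13.3 "reference"
+    # batches for the purpose of the schedules set via set_batch_count().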
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=1, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 2000, + "valid_interval": 20000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + 
use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. 
It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
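+                # The heuristic below doubles the scale while it is under 8.0 (or
+                # under 32.0, re-checked every 400 batches), logs a warning and saves
+                # a one-off "bad-model" checkpoint when it falls below 0.01, and
+                # aborts training once it drops below 1e-5, which usually indicates
+                # inf/nan gradients.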
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + gigaspeech = GigaSpeechAsrDataModule(args) + + train_cuts = gigaspeech.train_cuts() + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + 
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/zipformer/zipformer.py b/egs/gigaspeech/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 902dc2364a693ce7c6b939a0c9cf64382f767147 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 22 Oct 2023 23:25:06 +0800 Subject: [PATCH 078/113] Update docker for torch 2.1 (#1326) --- .github/workflows/build-docker-image.yml | 9 ++- .github/workflows/run-docker-image.yml | 15 ++++- .github/workflows/run-yesno-recipe.yml | 4 +- docker/torch1.12.1-cuda11.3.dockerfile | 5 +- docker/torch1.13.0-cuda11.6.dockerfile | 5 +- docker/torch1.9.0-cuda10.2.dockerfile | 3 +- docker/torch2.0.0-cuda11.7.dockerfile | 5 +- docker/torch2.1.0-cuda11.8.dockerfile | 71 ++++++++++++++++++++++++ docker/torch2.1.0-cuda12.1.dockerfile | 71 ++++++++++++++++++++++++ docs/source/docker/intro.rst | 2 + 10 files changed, 179 insertions(+), 11 deletions(-) create mode 100644 docker/torch2.1.0-cuda11.8.dockerfile create mode 100644 docker/torch2.1.0-cuda12.1.dockerfile diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 327f0ee45..e5d96dcdf 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout @@ -30,6 +30,13 @@ jobs: image=${{ matrix.image }} mv -v ./docker/$image.dockerfile ./Dockerfile + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + - name: Log in to Docker Hub uses: docker/login-action@v2 with: diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index 12604a132..d048923b6 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 @@ -30,8 +30,15 @@ jobs: uname -a cat /etc/*release + find / -name libcuda* 2>/dev/null + + ls -lh /usr/local/ + ls -lh /usr/local/cuda* + nvcc --version + ls -lh /usr/local/cuda-*/compat/* + # For torch1.9.0-cuda10.2 export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH @@ 
-41,6 +48,12 @@ jobs: # For torch2.0.0-cuda11.7 export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH + # For torch2.1.0-cuda11.8 + export LD_LIBRARY_PATH=/usr/local/cuda-11.8/compat:$LD_LIBRARY_PATH + + # For torch2.1.0-cuda12.1 + export LD_LIBRARY_PATH=/usr/local/cuda-12.1/compat:$LD_LIBRARY_PATH + which nvcc cuda_dir=$(dirname $(which nvcc)) diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 7d55a50e1..9ac848535 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -64,8 +64,8 @@ jobs: pip uninstall -y protobuf pip install --no-binary protobuf protobuf==3.20.* - pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl - pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install --no-deps --force-reinstall k2==1.24.4.dev20231021+cpu.torch1.13.1 -f https://k2-fsa.github.io/k2/cpu.html + pip install kaldifeat==1.25.1.dev20231022+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - name: Run yesno recipe shell: bash diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index 5338bdca7..ed746abe3 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive -ARG K2_VERSION="1.24.3.dev20230725+cuda11.3.torch1.12.1" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.3.torch1.12.1" +# python 3.7 +ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1" ARG TORCHAUDIO_VERSION="0.12.1+cu113" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index 4d2f96c8e..9657866e5 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive -ARG K2_VERSION="1.24.3.dev20230725+cuda11.6.torch1.13.0" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.6.torch1.13.0" +# python 3.9 +ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0" ARG TORCHAUDIO_VERSION="0.13.0+cu116" LABEL authors="Fangjun Kuang " diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index a7cef6dc8..a92af7ad0 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive +# python 3.7 ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda10.2.torch1.9.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0" ARG TORCHAUDIO_VERSION="0.9.0" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index d91fbc24f..07296e6f0 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8 ARG DEBIAN_FRONTEND=noninteractive -ARG K2_VERSION="1.24.3.dev20230718+cuda11.7.torch2.0.0" -ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.7.torch2.0.0" +# python 3.10 +ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0" ARG 
TORCHAUDIO_VERSION="2.0.0+cu117" LABEL authors="Fangjun Kuang " diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile new file mode 100644 index 000000000..e500e9a6a --- /dev/null +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -0,0 +1,71 @@ +FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0" +ARG TORCHAUDIO_VERSION="2.1.0+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile new file mode 100644 index 000000000..c3f12323e --- /dev/null +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -0,0 +1,71 @@ +FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0" +ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0" +ARG TORCHAUDIO_VERSION="2.1.0+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && 
\ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index b09247d85..9ead0df00 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -30,6 +30,8 @@ which will give you something like below: .. code-block:: bash + "torch2.1.0-cuda12.1" + "torch2.1.0-cuda11.8" "torch2.0.0-cuda11.7" "torch1.12.1-cuda11.3" "torch1.9.0-cuda10.2" From 92ef561ff71e531f243ff432561851bb4b93390a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 01:10:50 +0800 Subject: [PATCH 079/113] Minor fixes for torch.jit.script support (#1329) --- egs/aishell/ASR/transducer_stateless/decoder.py | 4 ++++ egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py | 4 ++++ egs/librispeech/ASR/pruned_transducer_stateless/decoder.py | 4 ++++ egs/librispeech/ASR/transducer_stateless/decoder.py | 4 ++++ egs/librispeech/ASR/zipformer/decoder.py | 5 ++++- 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index 70e9e6c96..130f080ec 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -70,6 +70,10 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py index 93e0f9f7e..8a55eb5c8 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py @@ -95,6 +95,10 @@ class Decoder(nn.Module): max_abs=1.0, prob=0.05, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 49b82c433..03847b449 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -74,6 +74,10 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() self.output_linear = nn.Linear(embedding_dim, vocab_size) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index a182d91e2..ac6292f63 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -71,6 +71,10 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git 
a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index e77e54118..492d63fc5 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from scaling import Balancer @@ -95,6 +94,10 @@ class Decoder(nn.Module): max_abs=1.0, prob=0.05, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ From f9980aa606d2ea9bf3d73d65309fa161b2bc4765 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 08:17:17 +0800 Subject: [PATCH 080/113] minor fixes (#1332) --- egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py | 1 + egs/librispeech/ASR/zipformer/decoder.py | 1 + 2 files changed, 2 insertions(+) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py index 8a55eb5c8..91f167204 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decoder.py @@ -99,6 +99,7 @@ class Decoder(nn.Module): # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` # when inference with torch.jit.script and context_size == 1 self.conv = nn.Identity() + self.balancer2 = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 492d63fc5..7ce44495b 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -98,6 +98,7 @@ class Decoder(nn.Module): # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` # when inference with torch.jit.script and context_size == 1 self.conv = nn.Identity() + self.balancer2 = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ From 4b791ced78aa5b6d2ccc2d78458a3ed984b26e7f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 24 Oct 2023 10:38:56 +0800 Subject: [PATCH 081/113] Fix CI tests (#1333) --- requirements-ci.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 1eba69764..e1232a768 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -12,7 +12,7 @@ graphviz==0.19.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.13.1+cpu six --f https://k2-fsa.org/nightly/ k2==1.23.4.dev20230319+cpu.torch1.13.1 +-f https://k2-fsa.github.io/k2/cpu.html k2==1.24.4.dev20231022+cpu.torch1.13.1 git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 From 3fb99400cf2c691f5c666fecd1415340820364a6 Mon Sep 17 00:00:00 2001 From: hairyputtar <148847552+hairyputtar@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:17:25 +0530 Subject: [PATCH 082/113] fix typos (#1336) * fix typo * fix typo * Update pruned_transducer_stateless.rst --- docs/source/contributing/how-to-create-a-recipe.rst | 2 +- docs/source/recipes/Streaming-ASR/introduction.rst | 2 +- .../librispeech/pruned_transducer_stateless.rst | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/contributing/how-to-create-a-recipe.rst b/docs/source/contributing/how-to-create-a-recipe.rst index a30fb9056..168a856c3 100644 --- a/docs/source/contributing/how-to-create-a-recipe.rst +++ 
b/docs/source/contributing/how-to-create-a-recipe.rst @@ -3,7 +3,7 @@ How to create a recipe .. HINT:: - Please read :ref:`follow the code style` to adjust your code sytle. + Please read :ref:`follow the code style` to adjust your code style. .. CAUTION:: diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst index ac77a51d1..28f5b8fbf 100644 --- a/docs/source/recipes/Streaming-ASR/introduction.rst +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -32,7 +32,7 @@ In icefall, we implement the streaming conformer the way just like what `WeNet < .. HINT:: If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer to `this pull request `_. After adding the code needed by streaming training, - you have to re-train it with the extra arguments metioned in the docs above to get a streaming model. + you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model. Streaming Emformer diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst index 2ca70bcf3..d6e424e2f 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -584,7 +584,7 @@ The following shows two examples (for the two types of checkpoints): - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and `espnet/nets/beam_search_transducer.py `_ - is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to next frame. - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it @@ -648,7 +648,7 @@ command to extract ``model.state_dict()``. .. caution:: ``--streaming-model`` and ``--causal-convolution`` require to be True to export - a streaming mdoel. + a streaming model. It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``. @@ -697,7 +697,7 @@ Export model using ``torch.jit.script()`` .. caution:: ``--streaming-model`` and ``--causal-convolution`` require to be True to export - a streaming mdoel. + a streaming model. It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later load it by ``torch.jit.load("cpu_jit.pt")``. 
From d76c3fe4726ccf7f1f53f5e0f0607aa3dfec12c0 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 16:24:46 +0800 Subject: [PATCH 083/113] Migrate zipformer model to other Chinese datasets (#1216) added zipformer recipe for AISHELL-1 --- ...pruned-transducer-stateless3-2022-06-20.sh | 2 +- .../run-aishell-zipformer-2023-10-24.sh | 103 ++ .../run-aishell-zipformer-2023-10-24.yml | 95 ++ egs/aidatatang_200zh/ASR/prepare.sh | 4 +- .../asr_datamodule.py | 3 +- egs/aishell/ASR/README.md | 6 +- egs/aishell/ASR/RESULTS.md | 158 +- egs/aishell/ASR/prepare.sh | 3 +- egs/aishell/ASR/zipformer/__init__.py | 0 egs/aishell/ASR/zipformer/asr_datamodule.py | 1 + egs/aishell/ASR/zipformer/beam_search.py | 1 + egs/aishell/ASR/zipformer/decode.py | 814 ++++++++++ egs/aishell/ASR/zipformer/decode_stream.py | 1 + egs/aishell/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/aishell/ASR/zipformer/export-onnx.py | 1 + egs/aishell/ASR/zipformer/export.py | 1 + egs/aishell/ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/aishell/ASR/zipformer/joiner.py | 1 + egs/aishell/ASR/zipformer/model.py | 1 + egs/aishell/ASR/zipformer/onnx_check.py | 1 + egs/aishell/ASR/zipformer/onnx_decode.py | 286 ++++ .../zipformer/onnx_pretrained-streaming.py | 1 + egs/aishell/ASR/zipformer/onnx_pretrained.py | 1 + egs/aishell/ASR/zipformer/optim.py | 1 + egs/aishell/ASR/zipformer/pretrained.py | 1 + egs/aishell/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + egs/aishell/ASR/zipformer/streaming_decode.py | 880 +++++++++++ egs/aishell/ASR/zipformer/subsampling.py | 1 + egs/aishell/ASR/zipformer/train.py | 1350 +++++++++++++++++ egs/aishell/ASR/zipformer/zipformer.py | 1 + egs/aishell2/ASR/README.md | 6 +- egs/aishell2/ASR/RESULTS.md | 8 +- egs/aishell2/ASR/prepare.sh | 6 +- egs/aishell4/ASR/README.md | 6 +- egs/aishell4/ASR/prepare.sh | 4 +- .../asr_datamodule.py | 2 +- .../ASR/pruned_transducer_stateless5/local | 1 + 42 files changed, 3741 insertions(+), 18 deletions(-) create mode 100755 .github/scripts/run-aishell-zipformer-2023-10-24.sh create mode 100644 .github/workflows/run-aishell-zipformer-2023-10-24.yml create mode 100644 egs/aishell/ASR/zipformer/__init__.py create mode 120000 egs/aishell/ASR/zipformer/asr_datamodule.py create mode 120000 egs/aishell/ASR/zipformer/beam_search.py create mode 100755 egs/aishell/ASR/zipformer/decode.py create mode 120000 egs/aishell/ASR/zipformer/decode_stream.py create mode 120000 egs/aishell/ASR/zipformer/decoder.py create mode 120000 egs/aishell/ASR/zipformer/encoder_interface.py create mode 120000 egs/aishell/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/aishell/ASR/zipformer/export-onnx.py create mode 120000 egs/aishell/ASR/zipformer/export.py create mode 120000 egs/aishell/ASR/zipformer/jit_pretrained.py create mode 120000 egs/aishell/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/aishell/ASR/zipformer/joiner.py create mode 120000 egs/aishell/ASR/zipformer/model.py create mode 120000 egs/aishell/ASR/zipformer/onnx_check.py create mode 100755 egs/aishell/ASR/zipformer/onnx_decode.py create mode 120000 egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/aishell/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/aishell/ASR/zipformer/optim.py create mode 120000 egs/aishell/ASR/zipformer/pretrained.py create mode 120000 
egs/aishell/ASR/zipformer/scaling.py create mode 120000 egs/aishell/ASR/zipformer/scaling_converter.py create mode 120000 egs/aishell/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/aishell/ASR/zipformer/streaming_decode.py create mode 120000 egs/aishell/ASR/zipformer/subsampling.py create mode 100755 egs/aishell/ASR/zipformer/train.py create mode 120000 egs/aishell/ASR/zipformer/zipformer.py create mode 120000 egs/aishell4/ASR/pruned_transducer_stateless5/local diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index 4c393f6be..c3640cfde 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -18,8 +18,8 @@ log "Downloading pre-commputed fbank from $fbank_url" git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests ln -s $PWD/aishell-test-dev-manifests/data . -log "Downloading pre-trained model from $repo_url" repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 +log "Downloading pre-trained model from $repo_url" git clone $repo_url repo=$(basename $repo_url) diff --git a/.github/scripts/run-aishell-zipformer-2023-10-24.sh b/.github/scripts/run-aishell-zipformer-2023-10-24.sh new file mode 100755 index 000000000..865e29799 --- /dev/null +++ b/.github/scripts/run-aishell-zipformer-2023-10-24.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +git lfs install + +fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests +log "Downloading pre-computed fbank from $fbank_url" + +git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests +ln -s $PWD/aishell-test-dev-manifests/data .
+ +log "=======================" +log "CI testing large model" +repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/ +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + +log "=======================" +log "CI testing medium model" +repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/ +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + + +for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + + +log "=======================" +log "CI testing small model" +repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/ +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + + +for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + diff --git a/.github/workflows/run-aishell-zipformer-2023-10-24.yml b/.github/workflows/run-aishell-zipformer-2023-10-24.yml new file mode 100644 index 000000000..f2fb44a5f --- /dev/null +++ b/.github/workflows/run-aishell-zipformer-2023-10-24.yml @@ -0,0 +1,95 @@ +# Copyright 2023 Zengrui Jin (Xiaomi Corp.) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-aishell-zipformer-2023-10-24 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_aishell_zipformer_2023_10_24-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_aishell_zipformer_2023_10_24: + if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-aishell-zipformer-2023-10-24.sh + + \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 2eb0b3718..40ee2eb97 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -7,6 +7,8 @@ set -eou pipefail stage=-1 stop_stage=100 +perturb_speed=true + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -77,7 +79,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for aidatatang_200zh" if [ ! 
-f data/fbank/.aidatatang_200zh.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True + ./local/compute_fbank_aidatatang_200zh.py --perturb-speed ${perturb_speed} touch data/fbank/.aidatatang_200zh.done fi fi diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 3667c2ad0..d491996b2 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule: group.add_argument( "--bucketing-sampler", type=str2bool, - default=True, + default=False, help="When enabled, the batches will come from buckets of " "similar duration (saves padding frames).", ) @@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule: shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=True, + buffer_size=50000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index b9064cede..176f065e5 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,10 +1,12 @@ # Introduction -Please refer to -for how to run models in this recipe. +Please refer to for how to run models in this recipe. +Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co., Ltd. +400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition. +(From [Open Speech and Language Resources](https://www.openslr.org/33/)) # Transducers diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index a2d32013a..0b22f41a1 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,6 +1,162 @@ ## Results -### Aishell training result(Stateless Transducer) +### Aishell training result (Stateless Transducer) + +#### Zipformer (Non-streaming) + +[./zipformer](./zipformer) + +It's a reworked Zipformer with Pruned RNNT loss. +**Caution**: It uses `--context-size=1`.
+ +##### normal-scaled model, number of model parameters: 73412551, i.e., 73.41 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.67 | 4.37 | --epoch 55 --avg 17 | +| modified beam search | 4.40 | 4.13 | --epoch 55 --avg 17 | +| fast beam search | 4.60 | 4.31 | --epoch 55 --avg 17 | + +Command for training is: +```bash +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --enable-musan 0 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 +``` + +Command for decoding is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode.py \ + --epoch 55 \ + --avg 17 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m +done +``` +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at <https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/> + + + +##### small-scaled model, number of model parameters: 30167139, i.e., 30.17 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.97 | 4.67 | --epoch 55 --avg 21 | +| modified beam search | 4.67 | 4.40 | --epoch 55 --avg 21 | +| fast beam search | 4.85 | 4.61 | --epoch 55 --avg 21 | + +Command for training is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 1 \ + --exp-dir zipformer/exp-small \ + --enable-musan 0 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --max-duration 1200 +``` + +Command for decoding is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode.py \ + --epoch 55 \ + --avg 21 \ + --exp-dir ./zipformer/exp-small \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at <https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/> + + +##### large-scaled model, number of model parameters: 157285130, i.e., 157.29 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.49 | 4.22 | --epoch 56 --avg 23 | +| modified beam search | 4.28 | 4.03 | --epoch 56 --avg 23 | +| fast beam search | 4.44 | 4.18 | --epoch 56 --avg 23 | + +Command for training is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --use-fp16 1 \ + --context-size 1 \ + --exp-dir ./zipformer/exp-large \ + --enable-musan 0 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 800 +``` + +Command for decoding is: +```bash
+for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode.py \ + --epoch 56 \ + --avg 23 \ + --exp-dir ./zipformer/exp-large \ + --lang-dir data/lang_char \ + --context-size 1 \ + --decoding-method $m \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at <https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/> + #### Pruned transducer stateless 7 streaming [./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index d5dbe5726..4feed55a8 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -8,6 +8,7 @@ set -eou pipefail nj=15 stage=-1 stop_stage=11 +perturb_speed=true # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -114,7 +115,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell" if [ ! -f data/fbank/.aishell.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell.py --perturb-speed True + ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} touch data/fbank/.aishell.done fi fi diff --git a/egs/aishell/ASR/zipformer/__init__.py b/egs/aishell/ASR/zipformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/aishell/ASR/zipformer/asr_datamodule.py b/egs/aishell/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/beam_search.py b/egs/aishell/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py new file mode 100755 index 000000000..1968904ae --- /dev/null +++ b/egs/aishell/ASR/zipformer/decode.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. 
The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, 
+ ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/decode_stream.py b/egs/aishell/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/aishell/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/decoder.py b/egs/aishell/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/aishell/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/encoder_interface.py b/egs/aishell/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/export-onnx-streaming.py b/egs/aishell/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/aishell/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/export-onnx.py b/egs/aishell/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/aishell/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git 
a/egs/aishell/ASR/zipformer/export.py b/egs/aishell/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/aishell/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/jit_pretrained.py b/egs/aishell/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/aishell/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py b/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/aishell/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/joiner.py b/egs/aishell/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/aishell/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/model.py b/egs/aishell/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/aishell/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_check.py b/egs/aishell/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_decode.py b/egs/aishell/ASR/zipformer/onnx_decode.py new file mode 100755 index 000000000..17c6eceb4 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_decode.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. +""" + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse.cut import Cut +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. 
", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: k2.SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + Mapping ids to tokens. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [[token_table[h] for h in hyp] for hyp in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: k2.SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + Mapping ids to tokens. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = k2.SymbolTable.from_file(args.tokens) + assert token_table[0] == "<blk>" + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py b/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/onnx_pretrained.py b/egs/aishell/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/aishell/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/optim.py b/egs/aishell/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/aishell/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/pretrained.py
b/egs/aishell/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/aishell/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/scaling.py b/egs/aishell/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/aishell/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/scaling_converter.py b/egs/aishell/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/aishell/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/streaming_beam_search.py b/egs/aishell/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/aishell/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..c3820447a --- /dev/null +++ b/egs/aishell/ASR/zipformer/streaming_decode.py @@ -0,0 +1,880 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +from asr_datamodule import AishellAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="Path to the lang dir(containing lexicon, tokens, etc.)", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + encoder_out=encoder_out, + streams=decode_streams, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + The Lexicon. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + if audio.max() > 1: + logging.warning( + f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." + f"Clipping to [-1, 1]." + ) + audio = np.clip(audio, -1, 1) + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + list(decode_streams[i].ground_truth.strip()), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + key = f"greedy_search_{key}" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_{key}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}_{key}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, 
iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + + dev_cuts = aishell.valid_cuts() + test_cuts = aishell.test_cuts() + + test_sets = ["dev", "test"] + test_cuts = [dev_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/subsampling.py b/egs/aishell/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/aishell/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py new file mode 100755 index 000000000..7e7b02829 --- /dev/null +++ b/egs/aishell/ASR/zipformer/train.py @@ -0,0 +1,1350 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --exp-dir zipformer/exp \ + --training-subset L + --lr-epochs 1.5 \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --training-subset L \ + --lr-epochs 1.5 \ + --max-duration 750 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. 
If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). 
We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + valid_cuts = aishell.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 12.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    AishellAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.lang_dir = Path(args.lang_dir)
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/aishell/ASR/zipformer/zipformer.py b/egs/aishell/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/README.md b/egs/aishell2/ASR/README.md
index ba38a1ec7..4e786af11 100644
--- a/egs/aishell2/ASR/README.md
+++ b/egs/aishell2/ASR/README.md
@@ -1,7 +1,11 @@
 # Introduction
-This recipe includes some different ASR models trained with Aishell2.
+This recipe contains various ASR models trained with Aishell2.
+
+In AISHELL-2, 1000 hours of clean read-speech data from iOS is published, which is free for academic usage. On top of the AISHELL-2 corpus, an improved recipe is developed and released, containing key components for industrial applications, such as Chinese word segmentation, flexible vocabulary expansion and phone set transformation, etc. Pipelines support various state-of-the-art techniques, such as time-delayed neural networks and the Lattice-Free MMI objective function. In addition, we also release dev and test data from other channels (Android and Mic).
+
+(From [AISHELL-2: Transforming Mandarin ASR Research Into Industrial Scale](https://arxiv.org/abs/1808.10583))
 [./RESULTS.md](./RESULTS.md) contains the latest results.
diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md
index 7114bd5f5..32ad74b50 100644
--- a/egs/aishell2/ASR/RESULTS.md
+++ b/egs/aishell2/ASR/RESULTS.md
@@ -1,8 +1,8 @@
 ## Results
-### Aishell2 char-based training results (Pruned Transducer 5)
+### Aishell2 char-based training results
-#### 2022-07-11
+#### Pruned transducer stateless 5
 Using the codes from this commit https://github.com/k2-fsa/icefall/pull/465.
@@ -41,9 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" The decoding command is: ```bash -for method in greedy_search modified_beam_search \ - fast_beam_search fast_beam_search_nbest \ - fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do +for method in greedy_search modified_beam_search fast_beam_search fast_beam_search_nbest fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do ./pruned_transducer_stateless5/decode.py \ --epoch 25 \ --avg 5 \ diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 42631c864..6eb6268f5 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -7,7 +7,9 @@ set -eou pipefail nj=30 stage=0 -stop_stage=5 +stop_stage=7 +perturb_speed=true + # We assume dl_dir (download dir) contains the following # directories and files. If not, you need to apply aishell2 through @@ -101,7 +103,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell2" if [ ! -f data/fbank/.aishell2.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell2.py --perturb-speed True + ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} touch data/fbank/.aishell2.done fi fi diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md index 3744032f8..67fa17790 100644 --- a/egs/aishell4/ASR/README.md +++ b/egs/aishell4/ASR/README.md @@ -1,7 +1,11 @@ # Introduction -This recipe includes some different ASR models trained with Aishell4 (including S, M and L three subsets). +This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets). + +The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. + +(From [Open Speech and Language Resources](https://www.openslr.org/111/)) [./RESULTS.md](./RESULTS.md) contains the latest results. diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index 1b1ec0005..361cc26ab 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -7,6 +7,8 @@ set -eou pipefail stage=-1 stop_stage=100 +perturb_speed=true + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -107,7 +109,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for aishell4" if [ ! 
-f data/fbank/.aishell4.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell4.py --perturb-speed True + ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} touch data/fbank/.aishell4.done fi fi diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 4ad98fb51..e6db2651f 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -306,7 +306,7 @@ class Aishell4AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=30000, + buffer_size=100000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/local b/egs/aishell4/ASR/pruned_transducer_stateless5/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/local @@ -0,0 +1 @@ +../local \ No newline at end of file From f82bccfd63d4f02fbe5050e3c2d972dc69656215 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 24 Oct 2023 19:04:09 +0800 Subject: [PATCH 084/113] Support CTC decoding for `multi-zh_hans` recipe (#1313) --- .../scripts/run-multi-zh_hans-zipformer.sh | 44 ++ .../workflows/run-multi-zh_hans-zipformer.yml | 2 +- egs/multi_zh-hans/ASR/RESULTS.md | 43 +- egs/multi_zh-hans/ASR/zipformer/ctc_decode.py | 625 ++++++++++++++++++ 4 files changed, 709 insertions(+), 5 deletions(-) create mode 100755 egs/multi_zh-hans/ASR/zipformer/ctc_decode.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh index 2bc3137d8..dd32a94f8 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -10,6 +10,7 @@ log() { cd egs/multi_zh-hans/ASR +log "==== Test icefall-asr-multi-zh-hans-zipformer-2023-9-2 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ log "Downloading pre-trained model from $repo_url" @@ -49,3 +50,46 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav done + +log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-20.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-ctc 1 \ + --method greedy_search \ +$repo/test_wavs/DEV_T0000000000.wav \ +$repo/test_wavs/DEV_T0000000001.wav \ +$repo/test_wavs/DEV_T0000000002.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --use-ctc 1 \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav +done \ No newline at end of file diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-zh_hans-zipformer.yml index 4ec81585f..72c0775a7 100644 --- 
a/.github/workflows/run-multi-zh_hans-zipformer.yml +++ b/.github/workflows/run-multi-zh_hans-zipformer.yml @@ -29,7 +29,7 @@ concurrency: jobs: run_multi-zh_hans_zipformer: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 31fbd9700..5133229a7 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -4,6 +4,41 @@ This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. +#### Non-streaming (with CTC head) + +Best results (num of params : ~69M): + +The training command: + +``` +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --use-fp16 1 \ + --max-duration 600 \ + --num-workers 8 \ + --use-ctc 1 +``` + +The decoding command: + +``` +./zipformer/decode.py \ + --epoch 20 \ + --avg 1 \ + --use-ctc 1 +``` + +Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model ( # tokens is 2000, byte fallback enabled). + +| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| +| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| CTC Decoding | 14.57 | 15.26 | 72.85 | 69.70 | 12.87 | 13.76 | 23.56 | 25.55 | 71.75 | 22.35 | 19.34 | 42.38 | 26.90 | 48.71 | 64.88 | 67.29 | 54.24 | +| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | + +Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ + #### Non-streaming Best results (num of params : ~69M): @@ -29,10 +64,10 @@ The decoding command: Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model ( # tokens is 2000, byte fallback enabled). 
-| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | |--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | +| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | -The pre-trained model is available here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 +Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..a7cd7ce43 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. 
+ Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
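+        # What the padding below does: append `pad_len` frames filled with
+        # LOG_EPS (a very small log-fbank value, i.e. near-silence) to the end
+        # of the features and grow `feature_lens` to match, so that the tail of
+        # the utterance can still be flushed through the chunk-wise (causal)
+        # encoder and decoded.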
+        pad_len = 30
+        feature_lens += pad_len
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, pad_len),
+            value=LOG_EPS,
+        )
+
+    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            torch.div(
+                supervisions["start_frame"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+        ),
+        1,
+    ).to(torch.int32)
+
+    assert bpe_model is not None
+    decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=ctc_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.decoding_method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
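+        # `len(dl)` raises TypeError when the underlying sampler does not
+        # define __len__ (e.g. with dynamic bucketing samplers); the number of
+        # batches is then unknown and is only used for progress logging below.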
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ("ctc-decoding",) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    HLG = None
+    H = k2.ctc_topo(
+        max_token=max_token_id,
+        modified=True,
+        device=device,
+    )
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+
+    G = None
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
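+    # Setting `return_cuts=True` makes each batch carry its lhotse cuts under
+    # batch["supervisions"]["cut"], which decode_dataset() uses to attach the
+    # cut IDs to the decoding results.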
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + test_sets_cuts = multi_dataset.test_cuts() + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() From 1814bbb0e7afcfcfe495322d2abd4fbfb21510c4 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 25 Oct 2023 00:03:33 +0800 Subject: [PATCH 085/113] typo fixed (#1334) --- egs/aishell/ASR/prepare.sh | 2 +- egs/aishell2/ASR/prepare.sh | 2 +- egs/gigaspeech/ASR/prepare.sh | 2 +- egs/librispeech/ASR/prepare.sh | 2 +- egs/librispeech/WSASR/prepare.sh | 2 +- egs/mgb2/ASR/prepare.sh | 2 +- egs/swbd/ASR/prepare.sh | 2 +- egs/tal_csasr/ASR/prepare.sh | 2 +- egs/tedlium3/ASR/prepare.sh | 2 +- egs/wenetspeech/ASR/prepare.sh | 2 +- egs/xbmu_amdo31/ASR/prepare.sh | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 4feed55a8..d36dc5ed3 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -243,7 +243,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then -lm data/lm/3-gram.unpruned.arpa fi - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then # It is used in building HLG diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 6eb6268f5..a5eb9bd13 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -159,7 +159,7 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm if [ ! 
-f ${lang_char_dir}/3-gram.unpruned.arpa ]; then diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index bd255dc6a..a23b708d7 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -293,7 +293,7 @@ fi if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Stage 12: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 739608572..4a5072cc0 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -278,7 +278,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/librispeech/WSASR/prepare.sh b/egs/librispeech/WSASR/prepare.sh index f6a922fde..0d2a67259 100755 --- a/egs/librispeech/WSASR/prepare.sh +++ b/egs/librispeech/WSASR/prepare.sh @@ -193,7 +193,7 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p "${lm_dir}" diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh index 899d15d97..4ea427371 100755 --- a/egs/mgb2/ASR/prepare.sh +++ b/egs/mgb2/ASR/prepare.sh @@ -188,7 +188,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh index 47d12613b..6b6f4ff86 100755 --- a/egs/swbd/ASR/prepare.sh +++ b/egs/swbd/ASR/prepare.sh @@ -311,7 +311,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" lang_dir=data/lang_phone - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh index 352e8ba66..2de4ac8f5 100755 --- a/egs/tal_csasr/ASR/prepare.sh +++ b/egs/tal_csasr/ASR/prepare.sh @@ -150,7 +150,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi # Prepare words.txt - # We assume you have install jieba, if not, please install + # We assume you have installed jieba, if not, please install # it using: pip install jieba if [ ! 
-f $lang_char_dir/words.txt ]; then python -m jieba $lang_char_dir/text | sed 's/\///g;s/\s\+/ /g' > $lang_char_dir/text.seg diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh index 3d90436ff..2f58ca0ee 100755 --- a/egs/tedlium3/ASR/prepare.sh +++ b/egs/tedlium3/ASR/prepare.sh @@ -172,7 +172,7 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 097a59a5f..f7eb9f0d0 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -237,7 +237,7 @@ fi if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then log "Stage 17: Prepare G" # It will take about 20 minutes. - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then python3 ./shared/make_kn_lm.py \ diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh index 32ae440f7..21836840c 100755 --- a/egs/xbmu_amdo31/ASR/prepare.sh +++ b/egs/xbmu_amdo31/ASR/prepare.sh @@ -224,7 +224,7 @@ fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then log "Stage 8: Prepare G" - # We assume you have install kaldilm, if not, please install + # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm mkdir -p data/lm From dcbc7a63e117c8fdd4003bb8d998d7a0b6376aa2 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 25 Oct 2023 12:50:35 +0800 Subject: [PATCH 086/113] Update train-rnn-lm.sh (#1337) --- egs/ptb/LM/train-rnn-lm.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh index 29c609ee1..cb70b7856 100755 --- a/egs/ptb/LM/train-rnn-lm.sh +++ b/egs/ptb/LM/train-rnn-lm.sh @@ -37,10 +37,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then --world-size $world_size \ --use-fp16 0 \ --vocab-size 500 \ - \ --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \ --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \ - \ --embedding-dim 800 \ --hidden-dim 200 \ --num-layers 2 \ @@ -56,9 +54,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --epoch $use_epoch \ --avg $use_avg \ --vocab-size 500 \ - \ --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \ - \ --embedding-dim 800 \ --hidden-dim 200 \ --num-layers 2 \ From 770c495484f2f244e1b54ea51fea1661f48a0a06 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 25 Oct 2023 17:14:17 +0800 Subject: [PATCH 087/113] minor fixes in the CTC decoding code (#1338) --- egs/multi_zh-hans/ASR/RESULTS.md | 4 ++-- egs/multi_zh-hans/ASR/zipformer/ctc_decode.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 5133229a7..15e789604 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -33,8 +33,8 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the | Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | 
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Decoding | 14.57 | 15.26 | 72.85 | 69.70 | 12.87 | 13.76 | 23.56 | 25.55 | 71.75 | 22.35 | 19.34 | 42.38 | 26.90 | 48.71 | 64.88 | 67.29 | 54.24 | +| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 | | Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py index a7cd7ce43..5143f945d 100755 --- a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py @@ -379,7 +379,8 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() + ref_words = list(ref_text.replace(" ", "")) + hyp_words = list("".join(hyp_words)) this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) From c0a53271e2fe64dd02939bb6e2ff3a2938715b48 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 26 Oct 2023 17:35:12 +0800 Subject: [PATCH 088/113] Update Zipformer-large result on LibriSpeech (#1343) * update zipformer-large result on librispeech --- README.md | 11 +++---- egs/librispeech/ASR/RESULTS.md | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a14abd023..81efda32a 100644 --- a/README.md +++ b/README.md @@ -118,11 +118,12 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles #### k2 pruned RNN-T -| Encoder | Params | test-clean | test-other | -|-----------------|--------|------------|------------| -| zipformer | 65.5M | 2.21 | 4.79 | -| zipformer-small | 23.2M | 2.42 | 5.73 | -| zipformer-large | 148.4M | 2.06 | 4.63 | +| Encoder | Params | test-clean | test-other | epochs | devices | +|-----------------|--------|------------|------------|---------|------------| +| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 | +| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 | +| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 | +| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. 
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index fc7fcdc26..a1808edd3 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -245,6 +245,58 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M, trained on 8 80G-A100 GPUs + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|-----------------------| +| greedy_search | 2.00 | 4.47 | --epoch 174 --avg 172 | +| modified_beam_search | 2.00 | 4.38 | --epoch 174 --avg 172 | +| fast_beam_search | 2.00 | 4.42 | --epoch 174 --avg 172 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 174 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --causal 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --full-libri 1 \ + --max-duration 2200 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search fast_beam_search; do + ./zipformer/decode.py \ + --epoch 174 \ + --avg 172 \ + --exp-dir zipformer/exp-large \ + --max-duration 600 \ + --causal 0 \ + --decoding-method $m \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +done +``` + #### streaming ##### normal-scaled model, number of model parameters: 66110931, i.e., 66.11 M From 800bf4b6a2e32745e7d0c31dd78d473f1faff509 Mon Sep 17 00:00:00 2001 From: hairyputtar <148847552+hairyputtar@users.noreply.github.com> Date: Fri, 27 Oct 2023 09:16:28 +0530 Subject: [PATCH 089/113] fix more typos (#1340) * fix more typos * fix typo * fix typo * fix typo --- docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst | 2 +- docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst | 2 +- .../librispeech/pruned_transducer_stateless.rst | 2 +- docs/source/recipes/RNN-LM/librispeech/lm-training.rst | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst index 6e30ce397..aad90f9d0 100644 --- a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst @@ -67,7 +67,7 @@ To run stage 2 to stage 5, use: .. HINT:: A 3-gram language model will be downloaded from huggingface, we assume you have - intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by + installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst index 9eb3b11f7..8e56deb6a 100644 --- a/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst @@ -67,7 +67,7 @@ To run stage 2 to stage 5, use: .. 
HINT:: A 3-gram language model will be downloaded from huggingface, we assume you have - intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by + installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index 1bc1dd984..f356e97e7 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -418,7 +418,7 @@ The following shows two examples (for two types of checkpoints): - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and `espnet/nets/beam_search_transducer.py `_ - is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to next frame. - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst index 736120275..46499a374 100644 --- a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst +++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst @@ -1,6 +1,6 @@ .. _train_nnlm: -Train an RNN langugage model +Train an RNN language model ====================================== If you have enough text data, you can train a neural network language model (NNLM) to improve From ea78b328575f9533c7d34db6f9cd0f44b09b6092 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 27 Oct 2023 13:35:43 +0800 Subject: [PATCH 090/113] minor fixes (#1345) --- egs/tedlium3/ASR/zipformer/decode.py | 4 ++-- egs/tedlium3/ASR/zipformer/train.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py index ea1cbba1b..2c4123c20 100755 --- a/egs/tedlium3/ASR/zipformer/decode.py +++ b/egs/tedlium3/ASR/zipformer/decode.py @@ -116,7 +116,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -695,7 +695,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 33d03908c..5ad01df27 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -586,7 +586,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner -def get_transducer_model(params: AttributeDict) -> nn.Module: +def get_model(params: AttributeDict) -> nn.Module: encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) decoder = get_decoder_model(params) @@ -1083,7 +1083,7 @@ def run(rank, world_size, args): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model 
parameters: {num_param}") From 5cebecf2dcebbfb7284cc2577d1e50a33933c663 Mon Sep 17 00:00:00 2001 From: Shreyas0410 <70795867+Shreyas0410@users.noreply.github.com> Date: Fri, 27 Oct 2023 11:06:15 +0530 Subject: [PATCH 091/113] updated broken link in read.me file (#1342) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 81efda32a..15e9e17e6 100644 --- a/README.md +++ b/README.md @@ -367,7 +367,7 @@ Once you have trained a model in icefall, you may want to deploy it with C++, without Python dependencies. Please refer to the documentation - + for how to do this. We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++. From 7d56685734cbdd9170caae7fada2d64b27cab2b3 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Fri, 27 Oct 2023 01:38:09 -0400 Subject: [PATCH 092/113] [recipe] LibriSpeech zipformer_ctc (#941) * merge upstream * initial commit for zipformer_ctc * remove unwanted changes * remove changes to other recipe * fix zipformer softlink * fix for JIT export * add missing file * fix symbolic links * update results * Update RESULTS.md Address comments from @csukuangfj --------- Co-authored-by: zr_jin --- egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 51 +- egs/librispeech/ASR/zipformer_ctc/__init__.py | 0 .../ASR/zipformer_ctc/asr_datamodule.py | 1 + egs/librispeech/ASR/zipformer_ctc/decode.py | 886 +++++++++++++ egs/librispeech/ASR/zipformer_ctc/decoder.py | 298 +++++ .../ASR/zipformer_ctc/encoder_interface.py | 1 + egs/librispeech/ASR/zipformer_ctc/export.py | 240 ++++ .../ASR/zipformer_ctc/label_smoothing.py | 1 + egs/librispeech/ASR/zipformer_ctc/model.py | 158 +++ egs/librispeech/ASR/zipformer_ctc/optim.py | 1 + egs/librispeech/ASR/zipformer_ctc/scaling.py | 1 + .../ASR/zipformer_ctc/scaling_converter.py | 1 + .../ASR/zipformer_ctc/subsampling.py | 1 + egs/librispeech/ASR/zipformer_ctc/train.py | 1135 +++++++++++++++++ .../ASR/zipformer_ctc/transformer.py | 1 + .../ASR/zipformer_ctc/zipformer.py | 1 + 17 files changed, 2777 insertions(+), 1 deletion(-) create mode 100644 egs/librispeech/ASR/zipformer_ctc/__init__.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py create mode 100755 egs/librispeech/ASR/zipformer_ctc/decode.py create mode 100644 egs/librispeech/ASR/zipformer_ctc/decoder.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/encoder_interface.py create mode 100755 egs/librispeech/ASR/zipformer_ctc/export.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/label_smoothing.py create mode 100644 egs/librispeech/ASR/zipformer_ctc/model.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/optim.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/scaling.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/scaling_converter.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/subsampling.py create mode 100755 egs/librispeech/ASR/zipformer_ctc/train.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/transformer.py create mode 120000 egs/librispeech/ASR/zipformer_ctc/zipformer.py diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index f42750da9..1c8930818 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -47,6 +47,7 @@ We place an additional Conv1d layer right after the input embedding layer. 
| `conformer-ctc` | Conformer | Use auxiliary attention head | | `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | | `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | +| `zipformer-ctc` | Zipformer | Use auxiliary attention head | | `zipformer` | Upgraded Zipformer | Use auxiliary transducer head | The latest recipe | # MMI diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index a1808edd3..ebf5e89c4 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -375,6 +375,55 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +### Zipformer CTC + +#### [zipformer_ctc](./zipformer_ctc) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 86083707, i.e., 86.08 M + +| decoding method | test-clean | test-other | comment | +|-------------------------|------------|------------|---------------------| +| ctc-decoding | 2.50 | 5.86 | --epoch 30 --avg 9 | +| whole-lattice-rescoring | 2.44 | 5.38 | --epoch 30 --avg 9 | +| attention-rescoring | 2.35 | 5.16 | --epoch 30 --avg 9 | +| 1best | 2.01 | 4.61 | --epoch 30 --avg 9 | + +The training commands are: +```bash + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer_ctc/exp \ + --full-libri 1 \ + --max-duration 1000 \ + --master-port 12345 +``` + +The tensorboard log can be found at: + + +The decoding command is: + +```bash +./zipformer_ctc/decode.py \ + --epoch 30 --avg 9 --use-averaged-model True \ + --exp-dir zipformer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --lm-dir data/lm \ + --method ctc-decoding +``` + ### pruned_transducer_stateless7 (Fine-tune with mux) See for more details. @@ -616,7 +665,6 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` - #### Smaller model We also provide a very small version (only 6.1M parameters) of this setup. The training command for the small model is: @@ -663,6 +711,7 @@ This small model achieves the following WERs on GigaSpeech test and dev sets: You can find the tensorboard logs at . + ### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) #### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) diff --git a/egs/librispeech/ASR/zipformer_ctc/__init__.py b/egs/librispeech/ASR/zipformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py new file mode 100755 index 000000000..7f605e2c8 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/decode.py @@ -0,0 +1,886 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_ctc_model, get_params +from transformer import encoder_padding_mask + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=77, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. 
Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + nnet_output, _ = model.encoder(feature, feature_lens) + ctc_output = model.ctc_output(nnet_output) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + + nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(nnet_output.size(0), supervisions) + mask = mask.to(nnet_output.device) if mask is not None else None + mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=mmodel, + memory=nnet_output, + memory_key_padding_mask=mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=mmodel, + memory=nnet_output, + memory_key_padding_mask=mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. 
+ + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
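For reference, `write_error_stats()` from `icefall.utils` produces the WER figures summarized in these files. Below is a minimal, self-contained sketch of the underlying metric (word-level edit distance divided by the number of reference words); the function is illustrative only and is not the icefall implementation:

```python
def simple_wer(ref: str, hyp: str) -> float:
    """Word error rate: word-level edit distance / number of reference words."""
    r, h = ref.split(), hyp.split()
    # dp[i][j] = edit distance between r[:i] and h[:j]
    dp = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        dp[i][0] = i
    for j in range(len(h) + 1):
        dp[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # deletion
                dp[i][j - 1] + 1,        # insertion
                dp[i - 1][j - 1] + cost,  # substitution
            )
    return dp[len(r)][len(h)] / max(len(r), 1)


assert simple_wer("the cat sat", "the cat sat") == 0.0
assert round(simple_wer("the cat sat", "the cat sit"), 2) == 0.33
```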
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_ctc_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + 
rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_ctc/decoder.py b/egs/librispeech/ASR/zipformer_ctc/decoder.py new file mode 100644 index 000000000..8dec048a1 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/decoder.py @@ -0,0 +1,298 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from label_smoothing import LabelSmoothingLoss +from torch.nn.utils.rnn import pad_sequence +from transformer import PositionalEncoding, TransformerDecoderLayer + + +class Decoder(nn.Module): + """This class implements Transformer based decoder for an attention-based encoder-decoder + model. + """ + + def __init__( + self, + num_layers: int, + num_classes: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + dropout: float = 0.1, + normalize_before: bool = True, + ): + """ + Args: + num_layers: + Number of layers. + num_classes: + Number of tokens of the modeling unit including blank. + d_model: + Dimension of the input embedding, and of the decoder output. 
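For orientation, the `forward()` method that follows builds the decoder inputs and targets by prepending SOS, appending EOS and then padding; in this recipe SOS and EOS share a single ID. A minimal sketch of that preparation step with made-up token IDs (the value 1 is hypothetical; -1 marks target positions that the loss ignores):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = [[25, 37, 9], [113, 4]]  # hypothetical BPE IDs for two utterances
sos_id = eos_id = 1                  # hypothetical shared SOS/EOS ID

ys_in = [torch.tensor([sos_id] + utt) for utt in token_ids]   # decoder inputs
ys_out = [torch.tensor(utt + [eos_id]) for utt in token_ids]  # decoder targets
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
# ys_in_pad:  [[1, 25, 37, 9], [1, 113, 4,  1]]
# ys_out_pad: [[25, 37, 9, 1], [113, 4, 1, -1]]
```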
+ """ + super().__init__() + + if num_layers > 0: + self.decoder_num_class = num_classes # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + @torch.jit.export + def forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). 
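The body of `decoder_nll` below realizes this with `torch.nn.functional.cross_entropy(..., ignore_index=-1, reduction="none")` and reshapes the result back to (batch, time). A small standalone sketch of that pattern with made-up sizes:

```python
import torch
import torch.nn.functional as F

N, T, C = 2, 4, 6                        # hypothetical batch size, length, vocab size
logits = torch.randn(N, T, C)
targets = torch.tensor([[3, 1, 2, -1],   # -1 marks padding
                        [4, 5, -1, -1]])
nll = F.cross_entropy(
    logits.reshape(-1, C),
    targets.reshape(-1),
    ignore_index=-1,
    reduction="none",
).reshape(N, T)                          # per-token NLL; 0.0 at the ignored positions
```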
+ """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: + """Generate a length mask for input. + The masked position are filled with True, + Unmasked positions are filled with False. + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). 
+ The mask can be used for masked self-attention. + For instance, if sz is 3, it returns:: + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + Args: + sz: mask size + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py b/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py new file mode 120000 index 000000000..b8529e0b7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py new file mode 100755 index 000000000..0ff50f128 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/export.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = get_ctc_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + # We won't use the forward() method of the model in C++, 
so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + convert_scaled_to_non_scaled(model, inplace=True) + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py b/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/model.py b/egs/librispeech/ASR/zipformer_ctc/model.py new file mode 100644 index 000000000..2aeb8a072 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/model.py @@ -0,0 +1,158 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from transformer import encoder_padding_mask + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.utils import encode_supervisions + + +class CTCModel(nn.Module): + """It implements a CTC model with an auxiliary attention head.""" + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + encoder_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + An instance of `EncoderInterface`. The shared encoder for the CTC and attention + branches + decoder: + An instance of `nn.Module`. This is the decoder for the attention branch. + encoder_dim: + Dimension of the encoder output. + decoder_dim: + Dimension of the decoder output. + vocab_size: + Number of tokens of the modeling unit including blank. 
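The `ctc_output` head constructed below is just dropout, a linear projection to the vocabulary size, and a log-softmax, so each encoder frame is mapped to a distribution over tokens. A minimal shape sketch with sizes roughly matching the recipe defaults (384-dim encoder output, 500 BPE tokens):

```python
import torch
import torch.nn as nn

encoder_dim, vocab_size = 384, 500   # roughly the defaults of this recipe
ctc_head = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_dim, vocab_size),
    nn.LogSoftmax(dim=-1),
)
encoder_out = torch.randn(8, 200, encoder_dim)  # (N, T, C) encoder output
log_probs = ctc_head(encoder_out)               # (N, T, vocab_size) per-frame log-probabilities
assert log_probs.shape == (8, 200, vocab_size)
```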
+ """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder = encoder + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + self.decoder = decoder + + @torch.jit.ignore + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + supervisions: torch.Tensor, + graph_compiler: BpeCtcTrainingGraphCompiler, + subsampling_factor: int = 1, + beam_size: int = 10, + reduction: str = "sum", + use_double_scores: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + Tensor of dimension (N, T, C) where N is the batch size, + T is the number of frames, and C is the feature dimension. + x_lens: + Tensor of dimension (N,) where N is the batch size. + supervisions: + Supervisions are used in training. + graph_compiler: + It is used to compile a decoding graph from texts. + subsampling_factor: + It is used to compute the `supervisions` for the encoder. + beam_size: + Beam size used in `k2.ctc_loss`. + reduction: + Reduction method used in `k2.ctc_loss`. + use_double_scores: + If True, use double precision in `k2.ctc_loss`. + Returns: + Return the CTC loss, attention loss, and the total number of frames. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + + nnet_output, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + # compute ctc log-probs + ctc_output = self.ctc_output(nnet_output) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=subsampling_factor + ) + num_frames = supervision_segments[:, 2].sum().item() + + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments.cpu(), + allow_truncate=subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=beam_size, + reduction=reduction, + use_double_scores=use_double_scores, + ) + + if self.decoder is not None: + nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mmodel = ( + self.decoder.module if hasattr(self.decoder, "module") else self.decoder + ) + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + mask = encoder_padding_mask(nnet_output.size(0), supervisions) + mask = mask.to(nnet_output.device) if mask is not None else None + att_loss = mmodel.forward( + nnet_output, + mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + else: + att_loss = torch.tensor([0]) + + return ctc_loss, att_loss, num_frames diff --git a/egs/librispeech/ASR/zipformer_ctc/optim.py b/egs/librispeech/ASR/zipformer_ctc/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/zipformer_ctc/scaling.py b/egs/librispeech/ASR/zipformer_ctc/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py b/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/subsampling.py b/egs/librispeech/ASR/zipformer_ctc/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py new file mode 100755 index 000000000..f40344357 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -0,0 +1,1135 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./zipformer_ctc/train.py \ + --exp-dir ./zipformer_ctc/exp \ + --world-size 4 \ + --full-libri 1 \ + --max-duration 500 \ + --num-epochs 30 +""" + +import argparse +import copy +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import CTCModel +from optim import Eden, LRScheduler, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. 
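The `--average-period` help above spells out the running-average update applied to `model_avg`. A toy numeric check of that formula on a single tensor (all values made up):

```python
import torch

batch_idx_train, average_period = 1000, 200
model_param = torch.tensor([2.0])      # current parameter value
model_avg_param = torch.tensor([1.0])  # running average so far

model_avg_param = (
    model_param * (average_period / batch_idx_train)
    + model_avg_param * ((batch_idx_train - average_period) / batch_idx_train)
)
assert torch.allclose(model_avg_param, torch.tensor([1.2]))  # 2.0 * 0.2 + 1.0 * 0.8
```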
+ """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + num_layers=params.num_decoder_layers, + num_classes=params.vocab_size, + d_model=int(params.encoder_dims.split(",")[-1]), + ) + return decoder + + +def get_ctc_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + + model = CTCModel( + encoder=encoder, + decoder=decoder, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
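+    Returns:
+      Return a tuple of two elements: the loss tensor (a weighted sum of the
+      CTC loss and the attention loss, controlled by --att-rate) and a
+      MetricsTracker object holding statistics for this batch, e.g. the
+      number of frames and the individual loss values.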
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + with torch.set_grad_enabled(is_training): + ctc_loss, att_loss, tot_frames = model( + feature, + feature_lens, + supervisions, + graph_compiler, + subsampling_factor=params.subsampling_factor, + beam_size=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + info = MetricsTracker() + info["frames"] = tot_frames + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + assert loss.requires_grad == is_training, f"{loss.requires_grad} != {is_training}" + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + logging.info("About to create model") + + model = get_ctc_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 25.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer_ctc/transformer.py b/egs/librispeech/ASR/zipformer_ctc/transformer.py new file mode 120000 index 000000000..4c890cf29 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/transformer.py @@ -0,0 +1 @@ +../conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/zipformer.py b/egs/librispeech/ASR/zipformer_ctc/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/zipformer_ctc/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file From 161ab90dfb951a49c7fb373861317fdcb9a9e7e4 Mon Sep 17 00:00:00 2001 From: Himanshu Kumar Mahto <93067059+HimanshuMahto@users.noreply.github.com> Date: Mon, 30 Oct 2023 06:37:42 +0530 Subject: [PATCH 093/113] Enhancing the contributing.md file (#1351) --- contributing.md | 58 ++++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/contributing.md b/contributing.md index c8f06fdae..0a1f9936e 100644 --- a/contributing.md +++ b/contributing.md @@ -1,39 +1,37 @@ +# Contributing to Our Project -## Pre-commit hooks +Thank you for your interest in contributing to our project! We use Git pre-commit hooks to ensure code quality and consistency. Before contributing, please follow these guidelines to enable and use the pre-commit hooks. -We use [git][git] [pre-commit][pre-commit] [hooks][hooks] to check that files -going to be committed: +## Pre-Commit Hooks - - contain no trailing spaces - - are formatted with [black][black] - - are compatible to [PEP8][PEP8] (checked by [flake8][flake8]) - - end in a newline and only a newline - - contain sorted `imports` (checked by [isort][isort]) +We have set up pre-commit hooks to check that the files you're committing meet our coding and formatting standards. These checks include: -These hooks are disabled by default. Please use the following commands to enable them: +- Ensuring there are no trailing spaces. +- Formatting code with [black](https://github.com/psf/black). +- Checking compliance with PEP8 using [flake8](https://flake8.pycqa.org/). +- Verifying that files end with a newline character (and only a newline). +- Sorting imports using [isort](https://pycqa.github.io/isort/). 
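+
+Once the hooks are installed (see the installation steps below), you can also
+run all of the checks manually across the whole repository with
+`pre-commit run --all-files`.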
-```bash -pip install pre-commit # run it only once -pre-commit install # run it only once, it will install all hooks +Please note that these hooks are disabled by default. To enable them, follow these steps: -# modify some files -git add -git commit # It runs all hooks automatically. +### Installation (Run only once) -# If all hooks run successfully, you can write the commit message now. Done! -# -# If any hook failed, your commit was not successful. -# Please read the error messages and make changes accordingly. -# And rerun +1. Install the `pre-commit` package using pip: + ```bash + pip install pre-commit + ``` +1. Install the Git hooks using: + ```bash + pre-commit install + ``` +### Making a Commit +Once you have enabled the pre-commit hooks, follow these steps when making a commit: +1. Make your changes to the codebase. +2. Stage your changes by using git add for the files you modified. +3. Commit your changes using git commit. The pre-commit hooks will run automatically at this point. +4. If all hooks run successfully, you can write your commit message, and your changes will be successfully committed. +5. If any hook fails, your commit will not be successful. Please read and follow the error messages provided, make the necessary changes, and then re-run git add and git commit. -git add -git commit -``` +### Your Contribution +Your contributions are valuable to us, and by following these guidelines, you help maintain code consistency and quality in our project. We appreciate your dedication to ensuring high-quality code. If you have questions or need assistance, feel free to reach out to us. Thank you for being part of our open-source community! -[git]: https://git-scm.com/book/en/v2/Customizing-Git-Git-Hooks -[flake8]: https://github.com/PyCQA/flake8 -[PEP8]: https://www.python.org/dev/peps/pep-0008/ -[black]: https://github.com/psf/black -[hooks]: https://github.com/pre-commit/pre-commit-hooks -[pre-commit]: https://github.com/pre-commit/pre-commit -[isort]: https://github.com/PyCQA/isort From c970df512b189b147e7fe6d45a7e8eb8609b9415 Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Mon, 30 Oct 2023 12:09:39 +0800 Subject: [PATCH 094/113] New recipe: tiny_transducer_ctc (#848) * initial commit * update readme * Update README.md * change bool to str2bool for arg parser * run validation only at the end of epoch * black format * black format --- .../ASR/tiny_transducer_ctc/README.md | 184 +++ .../ASR/tiny_transducer_ctc/asr_datamodule.py | 454 ++++++ .../ASR/tiny_transducer_ctc/beam_search.py | 1 + .../ASR/tiny_transducer_ctc/ctc_decode.py | 770 ++++++++++ .../ASR/tiny_transducer_ctc/decode.py | 717 ++++++++++ .../ASR/tiny_transducer_ctc/decoder.py | 1 + .../ASR/tiny_transducer_ctc/encoder.py | 379 +++++ .../tiny_transducer_ctc/encoder_interface.py | 1 + .../ASR/tiny_transducer_ctc/export.py | 316 +++++ .../ASR/tiny_transducer_ctc/jit_pretrained.py | 271 ++++ .../tiny_transducer_ctc/jit_pretrained_ctc.py | 426 ++++++ .../ASR/tiny_transducer_ctc/joiner.py | 1 + .../ASR/tiny_transducer_ctc/model.py | 1 + .../ASR/tiny_transducer_ctc/pretrained.py | 357 +++++ .../ASR/tiny_transducer_ctc/pretrained_ctc.py | 444 ++++++ .../ASR/tiny_transducer_ctc/scaling.py | 1 + .../ASR/tiny_transducer_ctc/train.py | 1251 +++++++++++++++++ 17 files changed, 5575 insertions(+) create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/README.md create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py create mode 
100644 egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/decode.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/decoder.py create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/encoder.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/export.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/joiner.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/model.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py create mode 100755 egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py create mode 120000 egs/librispeech/ASR/tiny_transducer_ctc/scaling.py create mode 100644 egs/librispeech/ASR/tiny_transducer_ctc/train.py diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/README.md b/egs/librispeech/ASR/tiny_transducer_ctc/README.md new file mode 100644 index 000000000..78dbc12c9 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/README.md @@ -0,0 +1,184 @@ +## Introduction + +This recipe is intended for streaming ASR on very low cost devices, with model parameters in the range of 1-2M. It uses a small convolutional net as the encoder. It is trained with combined transducer and CTC losses, and supports both phone and BPE lexicons. For phone lexicon, you can do transducer decoding using a method with LG, but the results were bad. + +The encoder consists of 2 subsampling layers followed by a stack of Conv1d-batchnorm-activation-causal_squeeze_excite blocks, with optional skip connections. To reduce latency (at the cost of slightly higher WER), half of the blocks use causal convolution. + +A few remarks & observations: + +1. Phone lexicon works better than BPE for CTC decoding (with HLG) but worse for transducer decoding. + +2. SpecAugment is not helpful for very small models as they tend to underfit rather than overfit. For the large model, a less aggressive SpecAugment (see asr_datamodule.py) improved the result a little. + +3. Squeeze-and-excitation worked like a charm! It reduces WER quite a bit with marginal increase of parameters and MAC ops. To make it causal I changed the global average pooling layer to a moving average filter, so only historical context is used. + +## Pretrained models + +You can find pretrained models, training logs, decoding logs, and decoding results at: + + +## Results on full libri + +I tried 3 different sizes of the encoder. The parameters are around 1M, 2M and 4M, respectively. For CTC decoding, whole-lattice-rescoring frequently causes OOM error so the result is not shown. + +### Small encoder + +The small encoder uses 10 layers of 1D convolution block with 256 channels, without skip connections. The encoder, decoder and joiner dim is 256. Algorithmic latency is 280ms. Multiply-add ops for the encoder is 22.0Mops. It is more applicable for ASR products with limited vocabulary (like a fixed set of phrases or short sentences). 
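+
+As an illustration of remark 3 above, below is a minimal PyTorch sketch of a
+causal squeeze-and-excitation block in which the usual global average pooling
+is replaced by a moving average over past frames only. This is *not* the
+actual code in `encoder.py`; the class name, reduction factor and context
+length are made up for illustration.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CausalSqueezeExcite(nn.Module):
+    """Squeeze-and-excitation that only uses historical context."""
+
+    def __init__(self, channels: int, reduction: int = 8, context: int = 16):
+        super().__init__()
+        self.context = context
+        self.fc1 = nn.Linear(channels, channels // reduction)
+        self.fc2 = nn.Linear(channels // reduction, channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (N, C, T), the output of a Conv1d stack.
+        # Causal moving average: pad on the left only, so frame t sees
+        # at most the previous `context` frames and nothing from the future.
+        padded = F.pad(x, (self.context - 1, 0))
+        avg = F.avg_pool1d(padded, kernel_size=self.context, stride=1)  # (N, C, T)
+        # Channel-wise gates computed from the causal average.
+        s = torch.sigmoid(self.fc2(torch.relu(self.fc1(avg.transpose(1, 2)))))
+        return x * s.transpose(1, 2)
+```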
+ +#### CTC decoding with phone lexicon +Total parameters: 1073392 + +Parameters for CTC decoding: 865816 + +| | test-clean | test-other | comment | +|-----------------|------------|------------|----------------------| +| 1best | 9.68 | 24.9 | --epoch 30 --avg 2 | +| nbest-rescoring | 8.2 | 22.7 | --epoch 30 --avg 2 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_small_phone \ + --ctc-loss-scale 0.7 \ + --enable-spec-aug 0 \ + --lang-dir lang_phone \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 10 \ + --channels 256 \ + --skip-add 0 \ +``` + +#### Transducer decoding with BPE 500 lexicon +Total parameters: 1623264 + +Parameters for transducer decoding: 1237764 + +| | test-clean | test-other | comment | +|--------------------|------------|------------|----------------------| +| greedy_search | 14.47 | 32.03 | --epoch 30 --avg 1 | +| fast_beam_search | 13.38 | 29.61 | --epoch 30 --avg 1 | +|modified_beam_search| 13.02 | 29.32 | --epoch 30 --avg 1 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_small_bpe \ + --ctc-loss-scale 0.2 \ + --enable-spec-aug 0 \ + --lang-dir lang_bpe_500 \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 10 \ + --channels 256 \ + --skip-add 0 \ +``` + +### Middle encoder + +The middle encoder uses 18 layers of 1D convolution block with 300 channels, with skip connections. The encoder, decoder and joiner dim is 256. Algorithmic latency is 440ms. Multiply-add ops for the encoder is 50.1Mops. Note that the nbest-rescoring result is better than the tdnn_lstm_ctc recipe with whole-lattice-rescoring. + +#### CTC decoding with phone lexicon +Total parameters: 2186242 + +Parameters for CTC decoding: 1978666 + +| | test-clean | test-other | comment | +|-----------------|------------|------------|----------------------| +| 1best | 7.48 | 18.94 | --epoch 30 --avg 1 | +| nbest-rescoring | 6.31 | 16.89 | --epoch 30 --avg 1 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_middle_phone \ + --ctc-loss-scale 0.7 \ + --enable-spec-aug 0 \ + --lang-dir lang_phone \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 18 \ + --channels 300 \ + --skip-add 1 \ +``` + +#### Transducer decoding with BPE 500 lexicon +Total parameters: 2735794 + +Parameters for transducer decoding: 2350294 + +| | test-clean | test-other | comment | +|--------------------|------------|------------|----------------------| +| greedy_search | 10.26 | 25.13 | --epoch 30 --avg 2 | +| fast_beam_search | 9.69 | 23.58 | --epoch 30 --avg 2 | +|modified_beam_search| 9.43 | 23.53 | --epoch 30 --avg 2 | + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_middle_bpe \ + --ctc-loss-scale 0.2 \ + --enable-spec-aug 0 \ + --lang-dir lang_bpe_500 \ + --encoder-dim 256 \ + --decoder-dim 256 \ + --joiner-dim 256 \ + --conv-layers 18 \ + --channels 300 \ + --skip-add 1 \ +``` + +### Large encoder + +The large encoder uses 18 layers of 1D convolution block with 400 channels, with skip connections. The encoder, decoder and joiner dim is 400. 
Algorithmic latency is 440ms. Multiply-add ops for the encoder is 88.8Mops. It is interesting to see how much the gap is if we simply scale down more complicated models like Zipformer or emformer. + + +#### Transducer decoding with BPE 500 lexicon +Total parameters: 4821330 + +Parameters for transducer decoding: 4219830 + +| | test-clean | test-other | comment | +|--------------------|------------|------------|----------------------| +| greedy_search | 8.29 | 21.11 | --epoch 30 --avg 1 | +| fast_beam_search | 7.91 | 20.1 | --epoch 30 --avg 1 | +|modified_beam_search| 7.74 | 19.89 | --epoch 30 --avg 1 | + + +The training commands are: +```bash + +./tiny_transducer_ctc/train.py \ + --num-epochs 30 \ + --full-libri 1 \ + --max-duration 600 \ + --exp-dir tiny_transducer_ctc/exp_large_bpe \ + --ctc-loss-scale 0.2 \ + --enable-spec-aug 1 \ + --lang-dir lang_bpe_500 \ + --encoder-dim 400 \ + --decoder-dim 400 \ + --joiner-dim 400 \ + --conv-layers 18 \ + --channels 400 \ + --skip-add 1 \ +``` diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py new file mode 100644 index 000000000..8facb6dba --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -0,0 +1,454 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
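+
+    A rough usage sketch (this mirrors how the train/decode scripts of this
+    recipe use the class; see :meth:`add_arguments` for all options)::
+
+        parser = argparse.ArgumentParser()
+        LibriSpeechAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+
+        librispeech = LibriSpeechAsrDataModule(args)
+        train_cuts = librispeech.train_clean_100_cuts()
+        train_dl = librispeech.train_dataloaders(train_cuts)
+        valid_dl = librispeech.valid_dataloaders(librispeech.dev_clean_cuts())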
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=False, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=0, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_feature_masks=2, + features_mask_size=5, + num_frame_masks=10, + frames_mask_size=5, + p=0.5, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get 
the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py b/egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py new file mode 100644 index 000000000..402aeac0c --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -0,0 +1,770 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
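+
+# A usage sketch (the paths and options below are only examples; see
+# get_parser() in this file for the full list of arguments):
+#
+#   ./tiny_transducer_ctc/ctc_decode.py \
+#     --epoch 30 \
+#     --avg 1 \
+#     --exp-dir tiny_transducer_ctc/exp_small_phone \
+#     --lang-dir data/lang_phone \
+#     --decoding-method 1best \
+#     --max-duration 600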
+ +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import pprint +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_phone", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="1best", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. 
+ Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.7, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "context_size": 2, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, _ = model.encoder(feature, feature_lens) + nnet_output = model.ctc_output(encoder_out) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="trunc", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="trunc", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + # lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + # lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + # lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + lm_scale_list = [0.6, 0.7, 0.8, 0.9] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
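+
+    # results maps a decoding-setting name (e.g. "ctc-decoding" or
+    # "lm_scale_0.7") to a list of (cut_id, ref_words, hyp_words) tuples.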
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-hlg-scale-{params.hlg_scale}" + + if params.use_averaged_model: + params.suffix += "-uam" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(pprint.pformat(params, indent=2)) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == 
"ctc-decoding": + assert "lang_bpe" in str( + params.lang_dir + ), "ctc-decoding only supports BPE lexicons." + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + enc_param = sum([p.numel() for p in model.encoder.parameters()]) + ctc_param = sum([p.numel() for p in model.ctc_output.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + logging.info(f"Parameters for CTC decoding: {enc_param + ctc_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py new file mode 100644 index 000000000..6c2bf9ea1 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -0,0 +1,717 @@ +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import pprint +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="fast_beam_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.1, + help=""" + Used only when --decoding_method is fast_beam_search_LG or + fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # 
fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
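+
+    # Greedy search is much faster than the other methods, so its progress
+    # is logged less frequently.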
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + if "lang_phone" in str(params.lang_dir): + assert params.decoding_method in ( + "fast_beam_search_LG", + "fast_beam_search_nbest_LG", + ), "For phone lexicon, use a decoding method with LG." 
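+
+    # Recognition results and logs go to a sub-directory of exp-dir named
+    # after the decoding method; the file-name suffix below encodes the
+    # checkpoint selection and the search hyper-parameters.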
+ + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-uam" + + setup_logger(f"{params.res_dir}/log-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + if "lang_bpe" in str(params.lang_dir): + sp = spm.SentencePieceProcessor() + sp.load(str(params.lang_dir / "bpe.model")) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + else: + params.blank_id = lexicon.token_table.get("") + params.unk_id = lexicon.token_table.get("SPN") + params.vocab_size = max(lexicon.tokens) + 1 + sp = None + + logging.info(pprint.pformat(params, indent=2)) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + 
filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + enc_param = sum([p.numel() for p in model.encoder.parameters()]) + dec_param = sum([p.numel() for p in model.decoder.parameters()]) + join_param = sum([p.numel() for p in model.joiner.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Parameters for transducer decoding: {enc_param + dec_param + join_param}" + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py new file mode 100644 index 000000000..4c7fca4fc --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022 Spacetouch Inc. (author: Tiance Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn.functional as F +from encoder_interface import EncoderInterface +from scaling import ActivationBalancer, DoubleSwish +from torch import Tensor, nn + + +class Conv1dNet(EncoderInterface): + """ + 1D Convolution network with causal squeeze and excitation + module and optional skip connections. + + Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride. + + Args: + output_dim (int): Number of output channels of the last layer. + input_dim (int): Number of input features + conv_layers (int): Number of convolution layers, + excluding the subsampling layers. + channels (int): Number of output channels for each layer, + except the last layer. + subsampling_factor (int): The subsampling factor for the model. + skip_add (bool): Whether to use skip connection for each convolution layer. + dscnn (bool): Whether to use depthwise-separated convolution. + activation (str): Activation function type. + """ + + def __init__( + self, + output_dim: int, + input_dim: int = 80, + conv_layers: int = 10, + channels: int = 256, + subsampling_factor: int = 4, + skip_add: bool = False, + dscnn: bool = True, + activation: str = "relu", + ) -> None: + super().__init__() + assert subsampling_factor == 4, "Only support subsampling = 4" + + self.conv_layers = conv_layers + self.skip_add = skip_add + # 80ms latency for subsample_layer + self.subsample_layer = nn.Sequential( + conv1d_bn_block( + input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn + ), + conv1d_bn_block( + channels, channels, 5, stride=2, activation=activation, dscnn=dscnn + ), + ) + + self.conv_blocks = nn.ModuleList() + cin = [channels] * conv_layers + cout = [channels] * (conv_layers - 1) + [output_dim] + + # Use causal and standard convolution alternatively + for ly in range(conv_layers): + self.conv_blocks.append( + nn.Sequential( + conv1d_bn_block( + cin[ly], + cout[ly], + 3, + activation=activation, + dscnn=dscnn, + causal=ly % 2, + ), + CausalSqueezeExcite1d(cout[ly], 16, 30), + ) + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.subsample_layer(x) + for idx, layer in enumerate(self.conv_blocks): + if self.skip_add and 0 < idx < self.conv_layers - 1: + x = layer(x) + x + else: + x = layer(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + lengths = x_lens >> 2 + return x, lengths + + +def get_activation( + name: str, + channels: int, + channel_dim: int = -1, + min_val: int = 0, + max_val: int = 1, +) -> nn.Module: + """ + Get activation function from name in string. + + Args: + name: activation function name + channels: only used for PReLU, should be equal to x.shape[1]. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + e.g. 
for NCHW tensor, channel_dim = 1 + min_val: minimum value of hardtanh + max_val: maximum value of hardtanh + + Returns: + The activation function module + + """ + act_layer = nn.Identity() + name = name.lower() + if name == "prelu": + act_layer = nn.PReLU(channels) + elif name == "relu": + act_layer = nn.ReLU() + elif name == "relu6": + act_layer = nn.ReLU6() + elif name == "hardtanh": + act_layer = nn.Hardtanh(min_val, max_val) + elif name in ["swish", "silu"]: + act_layer = nn.SiLU() + elif name == "elu": + act_layer = nn.ELU() + elif name == "doubleswish": + act_layer = nn.Sequential( + ActivationBalancer(num_channels=channels, channel_dim=channel_dim), + DoubleSwish(), + ) + elif name == "": + act_layer = nn.Identity() + else: + raise Exception(f"Unknown activation function: {name}") + + return act_layer + + +class CausalSqueezeExcite1d(nn.Module): + """ + Causal squeeze and excitation module with input and output shape + (batch, channels, time). The global average pooling in the original + SE module is replaced by a causal filter, so + the layer does not introduce any algorithmic latency. + + Args: + channels (int): Number of channels + reduction (int): channel reduction rate + context_window (int): Context window size for the moving average operation. + For EMA, the smoothing factor is 1 / context_window. + """ + + def __init__( + self, + channels: int, + reduction: int = 16, + context_window: int = 10, + ) -> None: + super(CausalSqueezeExcite1d, self).__init__() + + assert channels >= reduction + + self.context_window = context_window + c_squeeze = channels // reduction + self.linear1 = nn.Linear(channels, c_squeeze, bias=True) + self.act1 = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(c_squeeze, channels, bias=True) + self.act2 = nn.Sigmoid() + + # EMA worked better than MA empirically + # self.avg_filter = self.moving_avg + self.avg_filter = self.exponential_moving_avg + self.ema_matrix = torch.tensor([0]) + self.ema_matrix_size = 0 + + def _precompute_ema_matrix(self, N: int, device: torch.device): + a = 1.0 / self.context_window # smoothing factor + w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)] + w = torch.tensor(w).to(device).tril() + w[:, 0] *= self.context_window + self.ema_matrix = w.T + self.ema_matrix_size = N + + def exponential_moving_avg(self, x: Tensor) -> Tensor: + """ + Exponential moving average filter, which is calculated as: + y[t] = (1-a) * y[t-1] + a * x[t] + where a = 1 / self.context_window is the smoothing factor. + + For training, the iterative version is too slow. A better way is + to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication. + The weight matrix can be precomputed if the smoothing factor is fixed. + """ + if self.training: + # use matrix version to speed up training + N = x.shape[-1] + if N > self.ema_matrix_size: + self._precompute_ema_matrix(N, x.device) + y = torch.matmul(x, self.ema_matrix[:N, :N]) + else: + # use iterative version to save memory + a = 1.0 / self.context_window + y = torch.empty_like(x) + y[:, :, 0] = x[:, :, 0] + for t in range(1, y.shape[-1]): + y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t] + return y + + def moving_avg(self, x: Tensor) -> Tensor: + """ + Simple moving average with context_window as window size. 
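+
+        For frame t, the output is the mean of the last `context_window`
+        input frames (or of all frames seen so far, while fewer than
+        `context_window` frames are available), so no future frames are used.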
+ """ + y = torch.empty_like(x) + k = min(x.shape[2], self.context_window) + w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)] + w = torch.tensor(w, device=x.device) + y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T) + y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1) + return y + + def forward(self, x: Tensor) -> Tensor: + + assert len(x.shape) == 3, "Input is not a 3D tensor!" + y = self.exponential_moving_avg(x) + y = y.permute(0, 2, 1) # make channel last for squeeze op + y = self.act1(self.linear1(y)) + y = self.act2(self.linear2(y)) + y = y.permute(0, 2, 1) # back to original shape + y = x * y + return y + + +def conv1d_bn_block( + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + activation: str = "relu", + dscnn: bool = False, + causal: bool = False, +) -> nn.Sequential: + """ + Conv1d - batchnorm - activation block. + If kernel size is even, output length = input length + 1. + Otherwise, output and input lengths are equal. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + kernel_size (int): kernel size + stride (int): convolution stride + dilation (int): convolution dilation rate + dscnn (bool): Use depthwise separated convolution. + causal (bool): Use causal convolution + activation (str): Activation function type. + + """ + if dscnn: + return nn.Sequential( + CausalConv1d( + in_channels, + in_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=in_channels, + bias=False, + ) + if causal + else nn.Conv1d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=(kernel_size // 2) * dilation, + dilation=dilation, + groups=in_channels, + bias=False, + ), + nn.BatchNorm1d(in_channels), + get_activation(activation, in_channels), + nn.Conv1d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm1d(out_channels), + get_activation(activation, out_channels), + ) + else: + return nn.Sequential( + CausalConv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + bias=False, + ) + if causal + else nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=(kernel_size // 2) * dilation, + dilation=dilation, + bias=False, + ), + nn.BatchNorm1d(out_channels), + get_activation(activation, out_channels), + ) + + +class CausalConv1d(nn.Module): + """ + Causal convolution with padding automatically chosen to match input/output length. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ) -> None: + super(CausalConv1d, self).__init__() + assert kernel_size > 2 + + self.padding = dilation * (kernel_size - 1) + self.stride = stride + + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + self.padding, + dilation, + groups, + bias=bias, + ) + + def forward(self, x: Tensor) -> Tensor: + return self.conv(x)[:, :, : -self.padding // self.stride] diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py new file mode 100755 index 000000000..4117f7244 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 30 \ + --avg 2 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 30 \ + --avg 2 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `tiny_transducer_ctc/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./tiny_transducer_ctc/decode.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --epoch 9999 \ + --use-averaged-model 0 + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang-dir data/lang_bpe_500 \ + +Check ./pretrained.py for its usage. 
+ +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import UniqLexicon +from icefall.utils import str2bool +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp_4m_bpe500_halfdelay_specaug", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if "lang_bpe" in str(params.lang_dir): + sp = spm.SentencePieceProcessor() + sp.load(params.lang_dir + "/bpe.model") + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + else: + assert "lang_phone" in str(params.lang_dir) + phone_lexicon = UniqLexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(phone_lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
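+        # torch.jit.ignore() keeps forward() as an ordinary Python method and
+        # excludes it from TorchScript compilation; the scripted encoder,
+        # decoder and joiner remain usable from C++.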
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py new file mode 100755 index 000000000..3888d3544 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./tiny_transducer_ctc/jit_pretrained.py \ + --nn-model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
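+      If a sound file has more than one channel, only the first channel
+      is returned.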
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = 
fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py new file mode 100755 index 000000000..6f2cbaabd --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
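+Four decoding methods are supported and selected with --method:
+ctc-decoding, 1best, nbest-rescoring and whole-lattice-rescoring
+(see the usage examples below).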
+You can use the following command to get the exported models: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. 
+ It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
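+            # Keeping the n-gram scores in a separate attribute lets the
+            # rescoring functions below scale them with --ngram-lm-scale
+            # without touching the acoustic scores.  As a hedged
+            # illustration (not used by this script), you could sweep
+            # several scales instead of a single value, e.g. pass
+            #   lm_scale_list=[0.9, 1.1, 1.3, 1.5]
+            # to rescore_with_whole_lattice() / rescore_with_n_best_list()
+            # and keep the entry of the returned dict with the best WER.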
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/joiner.py b/egs/librispeech/ASR/tiny_transducer_ctc/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/model.py b/egs/librispeech/ASR/tiny_transducer_ctc/model.py new file mode 120000 index 000000000..545af927f --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_ctc/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py new file mode 100755 index 000000000..981039b8f --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
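+It supports four decoding methods, selected with --method: greedy_search,
+beam_search, modified_beam_search and fast_beam_search.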
+You can generate the checkpoint with the following command: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./tiny_transducer_ctc/pretrained.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./tiny_transducer_ctc/exp/epoch-xx.pt`. + +Note: ./tiny_transducer_ctc/exp/pretrained.pt is generated by +./tiny_transducer_ctc/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.lang_dir + "/bpe.model") + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + 
encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py new file mode 100755 index 000000000..a06d6d684 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -0,0 +1,444 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./tiny_transducer_ctc/export.py \ + --exp-dir ./tiny_transducer_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --lang-dir data/lang_bpe_500 \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./tiny_transducer_ctc/jit_pretrained_ctc.py \ + --checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. 
+ (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = 
k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/scaling.py b/egs/librispeech/ASR/tiny_transducer_ctc/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py new file mode 100644 index 000000000..307ad72aa --- /dev/null +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
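+
+# The module docstring below only covers data preparation.  A hedged example
+# of the training invocation itself (all flags are defined in get_parser()
+# further down; adjust the values to your setup):
+#
+#   python tiny_transducer_ctc/train.py \
+#     --world-size 1 \
+#     --num-epochs 30 \
+#     --exp-dir tiny_transducer_ctc/exp \
+#     --lang-dir data/lang_bpe_500 \
+#     --use-fp16 0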
+ +""" +Usage: + +cd egs/librispeech/ASR/ +./prepare.sh + +Run below if you want to use the phone lexicon instead of BPE: +python local/generate_unique_lexicon.py --lang-dir data/lang_phone + +""" +import argparse +import copy +import logging +import pprint +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from encoder import Conv1dNet +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +from torch.optim.lr_scheduler import StepLR +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics, is_module_available +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import UniqLexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + + parser.add_argument( + "--encoder-dim", + type=int, + default=256, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=256, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=256, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--conv-layers", + type=int, + default=10, + help="""Number of convolution layers for the encoder. + """, + ) + + parser.add_argument( + "--channels", + type=int, + default=256, + help="""Number of channels for the encoder. + """, + ) + + parser.add_argument( + "--skip-add", + type=str2bool, + default=False, + help="""Use skip connection in the encoder. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="tiny_transducer_ctc/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="""Weight for CTC loss, between 0 and 1. + When set to 0, only transducer loss is used. + When set to 1, only CTC loss is used.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. 
+ """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=5, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. 
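+
+          For example (an illustration of the schedule implemented in
+          compute_loss()): with warm_step = 5000 and --simple-loss-scale 0.5,
+          the weight on the simple loss decays linearly from 1.0 at batch 0
+          to 0.5 at batch 5000, while the pruned-loss weight ramps from 0.1
+          up to 1.0 over the same interval.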
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 200, + "warm_step": 5000, + "beam_size": 10, + "use_double_scores": True, + "env_info": get_env_info(), + "feature_dim": 80, + "subsampling_factor": 4, + "use_dscnn": True, + "activation": "doubleswish", + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + + encoder = Conv1dNet( + output_dim=params.encoder_dim, + input_dim=params.feature_dim, + conv_layers=params.conv_layers, + channels=params.channels, + subsampling_factor=params.subsampling_factor, + skip_add=params.skip_add, + dscnn=params.use_dscnn, + activation=params.activation, + ) + + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> Transducer: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + if is_module_available("thop"): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + # Assuming 10ms stride, 1000 frames is about 10 seconds. + x = torch.zeros((1, 1000, params.feature_dim)).to(device) + x_lens = torch.Tensor([1000]).int().to(device) + from thop import clever_format, profile + + m = copy.deepcopy(encoder) + m = m.to(device) + ops, _ = clever_format(profile(m, (x, x_lens), verbose=False)) + logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}") + else: + logging.info("You can install thop to calculate the number of ops.") + logging.info("Command: pip install thop") + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
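+    Returns:
+      Return a tuple of two elements: the loss (a scalar tensor summed over
+      the utterances in the batch, since reduction="sum" is used) and a
+      MetricsTracker holding the individual loss terms and the number of
+      frames in the batch.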
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + if sp is not None: + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + else: + y = phone_lexicon.texts_to_token_ids(texts).to(device) + token_ids = y.tolist() + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + assert 0 <= params.ctc_loss_scale <= 1, "ctc_loss_scale must be between 0 and 1" + loss = params.ctc_loss_scale * ctc_loss + (1 - params.ctc_loss_scale) * loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
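+    # The values recorded below are therefore sums over the whole batch;
+    # per-frame averages such as tot_loss["loss"] / tot_loss["frames"]
+    # are only computed later, when the stats are reported.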
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
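+            # As combined in compute_loss(), the total loss per batch is
+            #   ctc_loss_scale * ctc_loss
+            #   + (1 - ctc_loss_scale) * (simple_scale * simple_loss
+            #                             + pruned_scale * pruned_loss)
+            # where simple_scale and pruned_scale are warmed up over the
+            # first warm_step batches.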
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + # scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch( + batch, params=params, sp=sp, phone_lexicon=phone_lexicon + ) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + if "lang_bpe" in str(params.lang_dir): + sp = spm.SentencePieceProcessor() + sp.load(params.lang_dir + "/bpe.model") + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + phone_lexicon = None + else: + assert "lang_phone" in str(params.lang_dir) + phone_lexicon = UniqLexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(phone_lexicon.tokens) + 1 + sp = None + + logging.info(pprint.pformat(params, indent=2)) + + logging.info("About to create model") + model = get_transducer_model(params) + + if rank == 0: + num_param = sum([p.numel() for p in model.parameters()]) + enc_param = sum([p.numel() for p in model.encoder.parameters()]) + dec_param = sum([p.numel() for p in model.decoder.parameters()]) + join_param = sum([p.numel() for p in model.joiner.parameters()]) + ctc_param = sum([p.numel() for p in model.ctc_output.parameters()]) + + logging.info(f"Number of model parameters: {num_param}") + logging.info(f"Number of encoder parameters: {enc_param}") + logging.info(f"Number of decoder parameters: {dec_param}") + logging.info(f"Number of joiner parameters: {join_param}") + logging.info(f"Number of ctc parameters: {ctc_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = AdamW( + model.parameters(), + lr=params.initial_lr, + weight_decay=5e-4, + ) + + scheduler = StepLR(optimizer, step_size=2, gamma=0.8) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # phone_lexicon=phone_lexicon, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + # scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + phone_lexicon=phone_lexicon, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + if sp is not None: + y = sp.encode(supervisions["text"], out_type=int) + else: + y = phone_lexicon.texts_to_token_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + phone_lexicon: UniqLexicon, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + phone_lexicon=phone_lexicon, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch( + batch, params=params, sp=sp, phone_lexicon=phone_lexicon + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 23913f6afdea59caf703e3ac715852810cd246ad Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 31 Oct 2023 10:28:20 +0800 Subject: [PATCH 095/113] Minor refinements for some stale but recently merged PRs (#1354) * incorporate https://github.com/k2-fsa/icefall/pull/1269 * incorporate https://github.com/k2-fsa/icefall/pull/1301 * black formatted * incorporate https://github.com/k2-fsa/icefall/pull/1162 * black formatted --- egs/aishell/ASR/zipformer/train.py | 2 +- .../ASR/zipformer/asr_datamodule.py | 2 +- egs/gigaspeech/ASR/zipformer/train.py | 2 +- .../ASR/zipformer_prompt_asr/optim.py | 12 +++---- .../zipformer_prompt_asr/train_baseline.py | 2 +- .../train_bert_encoder.py | 2 +- .../ASR/tiny_transducer_ctc/asr_datamodule.py | 8 ++--- .../ASR/tiny_transducer_ctc/ctc_decode.py | 3 +- .../ASR/tiny_transducer_ctc/encoder.py | 1 - .../ASR/tiny_transducer_ctc/export.py | 31 ++++++----------- .../ASR/tiny_transducer_ctc/train.py | 4 +-- egs/librispeech/ASR/zipformer_ctc/export.py | 21 +++++------- egs/librispeech/ASR/zipformer_ctc/train.py | 2 +- icefall/diagnostics.py | 33 ++++++++++++------- 14 files changed, 57 insertions(+), 68 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index 7e7b02829..d381649e4 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index c4472ed23..6adfdbfbb 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git 
a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index d8ff4fecc..d93cc221c 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -1164,7 +1164,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py index a767761eb..159e363c7 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. 
batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 32302602c..c8b20d021 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -1194,7 +1194,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index e253d1118..9822b99c1 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -1565,7 +1565,7 @@ def run(rank, world_size, args): if params.print_diagnostics: args.max_duration = 100 opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py index 8facb6dba..3acd22ae4 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -307,8 +307,8 @@ class LibriSpeechAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py index 402aeac0c..cda03b56e 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -22,10 +22,11 @@ import argparse import logging import math +import pprint from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -import pprint + import k2 import sentencepiece as spm import torch diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py 
b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py index 4c7fca4fc..afdd00293 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py @@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module): return y def forward(self, x: Tensor) -> Tensor: - assert len(x.shape) == 3, "Input is not a 3D tensor!" y = self.exponential_moving_avg(x) y = y.permute(0, 2, 1) # make channel last for squeeze op diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py index 4117f7244..334dd011e 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/export.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py @@ -76,17 +76,17 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from train import add_model_arguments, get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) -from icefall.lexicon import UniqLexicon -from icefall.utils import str2bool -from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -143,13 +143,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -189,17 +186,9 @@ def main(): logging.info(f"device: {device}") - if "lang_bpe" in str(params.lang_dir): - sp = spm.SentencePieceProcessor() - sp.load(params.lang_dir + "/bpe.model") - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - else: - assert "lang_phone" in str(params.lang_dir) - phone_lexicon = UniqLexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(phone_lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 307ad72aa..8920764cd 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( "--encoder-dim", type=int, @@ -405,7 +404,6 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conv1dNet( output_dim=params.encoder_dim, input_dim=params.feature_dim, @@ -1043,7 +1041,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py index 0ff50f128..4c46aea2c 100755 --- a/egs/librispeech/ASR/zipformer_ctc/export.py +++ b/egs/librispeech/ASR/zipformer_ctc/export.py @@ -23,6 +23,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from 
train import add_model_arguments, get_ctc_model, get_params @@ -33,8 +34,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -90,11 +90,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -113,17 +112,15 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - logging.info(params) + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - params.vocab_size = num_classes + logging.info(params) device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index f40344357..60990456d 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -947,7 +947,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index ebf61784e..65b6f67b0 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -244,16 +244,14 @@ class TensorDiagnostic(object): if stats_type == "eigs": try: - if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'): + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): eigs, _ = torch.linalg.eigh(stats) else: eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." 
- ) - if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'): + print("Error getting eigenvalues, trying another method.") + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"): eigs, _ = torch.linalg.eig(stats) eigs = eigs.abs() else: @@ -579,10 +577,15 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=get_class_name(_module)) - + if isinstance(o, Tensor) and o.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] @@ -596,9 +599,15 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=get_class_name(_module)) + if isinstance(o, Tensor) and o.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) + module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) From 9e5a5d7839aa3052e46dcf25b239a37449f8cd5e Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 2 Nov 2023 16:10:08 +0800 Subject: [PATCH 096/113] Incorporate some latest changes to `optim.py` (#1359) * init commit * black formatted * isort formatted --- egs/librispeech/ASR/zipformer/optim.py | 171 +++++++++++++++++-------- 1 file changed, 121 insertions(+), 50 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8ee2b0eb4..a663db708 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch from lhotse.utils import fix_random_seed -from torch import Tensor +from torch import Tensor, nn from torch.optim import Optimizer @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for (stacked_params, _state, _names), batch in zip(tuples, batches): + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,6 +181,7 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): + defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -326,7 +327,9 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): + with self.batched_params(group["params"], group_params_names) as batches: + # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -423,16 +426,19 @@ class ScaledAdam(BatchedOptimizer): # parameters' state won't have been initialized yet. 
return 1.0 clipping_update_period = group["clipping_update_period"] + scalar_lr_scale = group["scalar_lr_scale"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for p, state, param_names in tuples: + for (p, state, param_names) in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + tot_sumsq += (grad**2).sum() * ( + scalar_lr_scale**2 + ) # sum() to change shape [1] to [] else: tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() @@ -443,64 +449,72 @@ class ScaledAdam(BatchedOptimizer): ) first_state["model_norms"][step % clipping_update_period] = tot_norm - if step % clipping_update_period == 0: + irregular_estimate_steps = [ + i for i in [10, 20, 40] if i < clipping_update_period + ] + if step % clipping_update_period == 0 or step in irregular_estimate_steps: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + if step in irregular_estimate_steps: + sorted_norms = sorted_norms[-step:] + num_norms = sorted_norms.numel() quartiles = [] for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) + index = min(num_norms - 1, (num_norms // 4) * n) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median + if step in irregular_estimate_steps: + # use larger thresholds on first few steps of estimating threshold, + # as norm may be changing rapidly. + threshold = threshold * 2.0 first_state["model_norm_threshold"] = threshold percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period + first_state["num_clipped"] * 100.0 / num_norms if "num_clipped" in first_state else 0.0 ) first_state["num_clipped"] = 0 quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( + logging.warn( f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" ) - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - if ans != ans: # e.g. ans is nan - ans = 0.0 - if ans == 0.0: - for p, state, param_names in tuples: - p.grad.zero_() # get rid of infinity() + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + return 1.0 # threshold has not yet been set. - return ans + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans != ans: # e.g. 
ans is nan + ans = 0.0 + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter( + tuples, tot_sumsq, group["scalar_lr_scale"] + ) + + if ans == 0.0: + for (p, state, param_names) in tuples: + p.grad.zero_() # get rid of infinity() + + return ans def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + tot_sumsq: Tensor, + scalar_lr_scale: float, ): """ Show information of parameter which dominates tot_sumsq. @@ -516,29 +530,30 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for p, state, batch_param_names in tuples: + for (p, state, batch_param_names) in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 # Dummy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) + batch_rms_orig = torch.full( + p.shape, scalar_lr_scale, device=batch_grad.device + ) else: batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2 + if batch_grad.ndim > 1: + # need to guard it with if-statement because sum() sums over + # all dims if dim == (). + batch_sumsq_orig = batch_sumsq_orig.sum( dim=list(range(1, batch_grad.ndim)) ) - for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): + proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) sorted_by_proportion = { k: v for k, v in sorted( @@ -552,7 +567,7 @@ class ScaledAdam(BatchedOptimizer): dominant_rms, dominant_grad, ) = sorted_by_proportion[dominant_param_name] - logging.info( + logging.warn( f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" @@ -826,7 +841,7 @@ class LRScheduler(object): def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: - logging.info( + logging.warn( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) @@ -841,8 +856,14 @@ class Eden(LRScheduler): where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches and then stays constant at 1. + If you don't have the concept of epochs, or one epoch takes a very long time, + you can replace the notion of 'epoch' with some measure of the amount of data + processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to + some measure representing "quite a lot of data": say, one fifth or one third + of an entire training run, but it doesn't matter much. You could also use + Eden2 which has only the notion of batches. - E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam Args: optimizer: the optimizer to change the learning rates on @@ -888,6 +909,56 @@ class Eden(LRScheduler): return [x * factor * warmup_factor for x in self.base_lrs] +class Eden2(LRScheduler): + """ + Eden2 scheduler, simpler than Eden because it does not use the notion of epoch, + only batches. + + The basic formula (before warmup) is: + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup + + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super().__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.5 + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + def _test_eden(): m = torch.nn.Linear(100, 100) optim = ScaledAdam(m.parameters(), lr=0.03) From c3bbb32f9ec6402f20582020eed64b159c55796f Mon Sep 17 00:00:00 2001 From: wnywbyt <45236066+wnywbyt@users.noreply.github.com> Date: Thu, 2 Nov 2023 20:45:30 +0800 Subject: [PATCH 097/113] Update the parameter 'vocab-size' (#1364) Co-authored-by: wdq --- egs/wenetspeech/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index f7eb9f0d0..b0525de60 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -362,6 +362,6 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then --exp-dir rnnlm_char/exp \ --lm-data data/lm_char/sorted_lm_data.pt \ --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ - --vocab-size 4336 \ + --vocab-size 5537 \ --master-port 12340 fi \ No newline at end of file From 231bbcd2b638826a94cf019fa31ae8683d3552ee Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 3 Nov 2023 12:06:29 +0800 Subject: [PATCH 098/113] Update optim.py (#1366) --- egs/librispeech/ASR/zipformer/optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index a663db708..714d8db9a 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -466,6 +466,8 @@ class ScaledAdam(BatchedOptimizer): quartiles.append(sorted_norms[index].item()) median = quartiles[2] + if median - median != 0: + raise RuntimeError("Too many grads were not finite") threshold = clipping_scale * median if step in irregular_estimate_steps: # use larger thresholds on first few steps of estimating threshold, From 1b2e99d374cbbc527bf8c9239d616497249ccb1d Mon Sep 17 00:00:00 2001 From: lishaojie 
<95117087+manbaaaa@users.noreply.github.com> Date: Thu, 9 Nov 2023 22:07:28 +0800 Subject: [PATCH 099/113] add the pruned_transducer_stateless7_streaming recipe for commonvoice (#1018) * add the pruned_transducer_stateless7_streaming recipe for commonvoice * fix the symlinks * Update RESULTS.md --- egs/commonvoice/ASR/RESULTS.md | 25 + egs/commonvoice/ASR/local/compile_hlg.py | 1 + egs/commonvoice/ASR/local/compile_lg.py | 1 + .../compute_fbank_commonvoice_dev_test.py | 4 +- .../ASR/local/preprocess_commonvoice.py | 10 +- egs/commonvoice/ASR/prepare.sh | 64 +- .../README.md | 9 + .../beam_search.py | 1 + .../commonvoice_fr.py | 422 ++++++ .../decode.py | 810 ++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export-for-ncnn-zh.py | 1 + .../export-for-ncnn.py | 1 + .../export-onnx.py | 1 + .../export.py | 1 + .../finetune.py | 1342 +++++++++++++++++ .../generate_model_from_checkpoint.py | 281 ++++ .../jit_pretrained.py | 1 + .../jit_trace_export.py | 1 + .../jit_trace_pretrained.py | 1 + .../joiner.py | 1 + .../model.py | 1 + .../onnx_check.py | 1 + .../onnx_model_wrapper.py | 1 + .../onnx_pretrained.py | 1 + .../optim.py | 1 + .../pretrained.py | 1 + .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming-ncnn-decode.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 612 ++++++++ .../test_model.py | 150 ++ .../train.py | 1256 +++++++++++++++ .../train2.py | 1257 +++++++++++++++ .../zipformer.py | 1 + .../zipformer2.py | 1 + icefall/shared/convert-k2-to-openfst.py | 103 +- icefall/shared/ngram_entropy_pruning.py | 631 +------- icefall/shared/parse_options.sh | 98 +- 42 files changed, 6260 insertions(+), 840 deletions(-) create mode 120000 egs/commonvoice/ASR/local/compile_hlg.py create mode 120000 egs/commonvoice/ASR/local/compile_lg.py create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py create mode 
120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py mode change 100755 => 120000 icefall/shared/convert-k2-to-openfst.py mode change 100755 => 120000 icefall/shared/ngram_entropy_pruning.py mode change 100755 => 120000 icefall/shared/parse_options.sh diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md index 751625371..2c158d91d 100644 --- a/egs/commonvoice/ASR/RESULTS.md +++ b/egs/commonvoice/ASR/RESULTS.md @@ -57,3 +57,28 @@ Pretrained model is available at The tensorboard log for training is available at + + +### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming) + +#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See #1018 for more details. + +Number of model parameters: 70369391, i.e., 70.37 M + +The best WER for Common Voice French 12.0 (cv-corpus-12.0-2022-12-07/fr) is below: + +Results are: + +| decoding method | Test | +|----------------------|-------| +| greedy search | 9.95 | +| modified beam search | 9.57 | +| fast beam search | 9.67 | + +Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice. 
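+
+The decoding commands take roughly the following form (a sketch only: the exact
+`--epoch`/`--avg` values behind the numbers above are not recorded here, and the
+flag names are assumed to carry over from the librispeech
+pruned_transducer_stateless7_streaming recipe):
+
+```bash
+./pruned_transducer_stateless7_streaming/decode.py \
+  --epoch 30 \
+  --avg 10 \
+  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+  --language fr \
+  --max-duration 600 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+```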
+ +Detailed experimental results and Pretrained model are available at + + diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py index c8f9b6ccb..a0b4d224c 100755 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py @@ -56,8 +56,8 @@ def get_args(): def compute_fbank_commonvoice_dev_test(language: str): src_dir = Path(f"data/{language}/manifests") output_dir = Path(f"data/{language}/fbank") - num_workers = 42 - batch_duration = 600 + num_workers = 16 + batch_duration = 200 subsets = ("dev", "test") diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index e60459765..5f6aa3ec0 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -43,9 +43,13 @@ def get_args(): return parser.parse_args() -def normalize_text(utt: str) -> str: +def normalize_text(utt: str, language: str) -> str: utt = re.sub(r"[{0}]+".format("-"), " ", utt) - return re.sub(r"[^a-zA-Z\s']", "", utt).upper() + utt = re.sub("’", "'", utt) + if language == "en": + return re.sub(r"[^a-zA-Z\s]", "", utt).upper() + if language == "fr": + return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() def preprocess_commonvoice( @@ -94,7 +98,7 @@ def preprocess_commonvoice( for sup in m["supervisions"]: text = str(sup.text) orig_text = text - sup.text = normalize_text(sup.text) + sup.text = normalize_text(sup.text, language) text = str(sup.text) if len(orig_text) != len(text): logging.info( diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh index 3946908c6..edac0e8e6 100755 --- a/egs/commonvoice/ASR/prepare.sh +++ b/egs/commonvoice/ASR/prepare.sh @@ -36,8 +36,8 @@ num_splits=1000 # - speech dl_dir=$PWD/download -release=cv-corpus-13.0-2023-03-09 -lang=en +release=cv-corpus-12.0-2022-12-07 +lang=fr . shared/parse_options.sh || exit 1 @@ -146,7 +146,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then ./local/compute_fbank_commonvoice_splits.py \ --num-workers $nj \ - --batch-duration 600 \ + --batch-duration 200 \ --start 0 \ --num-splits $num_splits \ --language $lang @@ -189,7 +189,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then sed -i 's/\t/ /g' $lang_dir/transcript_words.txt sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt fi - + if [ ! -f $lang_dir/words.txt ]; then cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ | sort -u | sed '/^$/d' > $lang_dir/words.txt @@ -216,14 +216,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then }' > $lang_dir/words || exit 1; mv $lang_dir/words $lang_dir/words.txt fi - + if [ ! 
-f $lang_dir/bpe.model ]; then ./local/train_bpe_model.py \ --lang-dir $lang_dir \ --vocab-size $vocab_size \ --transcript $lang_dir/transcript_words.txt fi - + if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang_bpe.py --lang-dir $lang_dir @@ -250,3 +250,55 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then fi done fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + mkdir -p $lang_dir/lm + #3-gram used in building HLG, 4-gram used for LM rescoring + for ngram in 3 4; do + if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_words.txt \ + -lm $lang_dir/lm/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt + fi + done + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Compile HLG" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Compile LG" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 100644 index 000000000..991875aaa --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1,9 @@ +This recipe implements Streaming Zipformer-Transducer model. + +See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. + +[./emformer.py](./emformer.py) and [./train.py](./train.py) +are basically the same as +[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
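+
+Training on the French subset can be started with a command of roughly the
+following shape (a sketch: `--world-size`, `--num-epochs` and `--max-duration`
+are placeholders to adapt to your hardware, and `--language` is the option
+added by `commonvoice_fr.py`):
+
+```bash
+./pruned_transducer_stateless7_streaming/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --exp-dir pruned_transducer_stateless7_streaming/exp \
+  --language fr \
+  --max-duration 300
+```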
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py new file mode 100644 index 000000000..cafa4111d --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -0,0 +1,422 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class CommonVoiceAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
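+
+    A rough usage sketch, mirroring how the train/decode scripts in this recipe
+    consume the class (the method names are the ones defined below):
+
+        parser = argparse.ArgumentParser()
+        CommonVoiceAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+
+        commonvoice = CommonVoiceAsrDataModule(args)
+        train_dl = commonvoice.train_dataloaders(commonvoice.train_cuts())
+        valid_dl = commonvoice.valid_dataloaders(commonvoice.dev_cuts())
+        test_dl = commonvoice.test_dataloaders(commonvoice.test_cuts())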
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--language", + type=str, + default="fr", + help="""Language of Common Voice""", + ) + group.add_argument( + "--cv-manifest-dir", + type=Path, + default=Path("data/fr/fbank"), + help="Path to directory with CommonVoice train/dev/test cuts.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz" + ) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..30f7c1e77 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,810 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + 
fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from commonvoice_fr import CommonVoiceAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += 
f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + # errs_info = ( + # params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + # ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + commonvoice = CommonVoiceAsrDataModule(args) + + test_cuts = commonvoice.test_cuts() + + test_dl = commonvoice.test_dataloaders(test_cuts) + + test_sets = "test-cv" + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_sets, + results_dict=results_dict, + ) + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 120000 index 000000000..72e43c297 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 120000 index 000000000..3b36924ef --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 000000000..57a0cd0a0 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 000000000..2acafdc61 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git 
a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py new file mode 100755 index 000000000..3a10c5d81 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--do-finetune", type=str2bool, default=False) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. 
It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="""Embedding dimension in the 2 blocks of zipformer encoder + layers, comma separated + """, + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers,\ + comma separated; not the same as embedding dimension. + """, + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="""Unmasked dimensions in the encoders, relates to augmentation + during training. Must be <= each of encoder_dims. Empirically, less + than 256 seems to make performance worse. + """, + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to the BPE model. + This should be the bpe model of the original model + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.005, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. 
+ For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
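+            # Note (illustrative): warm_step below controls how quickly the
+            # simple-loss weight decays towards --simple-loss-scale; see the
+            # schedule in compute_loss().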
+ "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
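+            # Illustrative: with --use-fp16 the loss is scaled before
+            # backward(); scaler.step() later unscales the gradients and
+            # skips the optimizer step if they contain inf/nan values.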
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have + # different behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    # load model parameters for model fine-tuning
+    if params.do_finetune:
+        modules = params.init_modules.split(",") if params.init_modules else None
+        checkpoints = load_model_params(
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+        )
+    else:
+        assert params.start_epoch > 0, params.start_epoch
+        checkpoints = load_checkpoint_if_available(
+            params=params, model=model, model_avg=model_avg
+        )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    commonvoice = CommonVoiceAsrDataModule(args)
+
+    train_cuts = commonvoice.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py new file mode 100755 index 000000000..3fd14aa47 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. + +(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. + +(3) use the original model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path +from typing import Dict, List + +import sentencepiece as spm +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model." + "If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    print("Script started")
+
+    device = torch.device("cpu")
+    print(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    print("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            print(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+            filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+            torch.save({"model": model.state_dict()}, filename)
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+            filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save({"model": model.state_dict()}, filename)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            print(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+            filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save({"model": model.state_dict()}, filename)
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            print(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+            filename = (
+                params.exp_dir
+                / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
+            )
+            torch.save({"model": model.state_dict()}, filename)
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            print(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 120000 index 000000000..5d9c6ba00 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 120000 index 000000000..457131699 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 120000 index 000000000..2b8fa3cbb --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 000000000..28bf7bb82 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 000000000..c8548d459 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 000000000..ae4d9bb04 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 000000000..9510b8fde --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..2adf271c1 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..dbe65d0a7 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
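+    # Illustrative arithmetic only: padding every chunk to at least 23 feature
+    # frames leaves (23 - 7) // 2 = 8 frames after the Conv2d subsampling, i.e.
+    # one frame for the encoder stack with the largest downsampling factor (8),
+    # and ((23 - 7) // 2 + 1) // 2 = 4 frames at the final encoder frame rate,
+    # which appears to be why tail_length is set to 23 below.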
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + idx = 0 + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
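+        # Roughly: each DecodeStream bundles the per-utterance encoder states,
+        # features and partial hypothesis; decode_one_chunk() advances all
+        # active streams by one chunk, and streams are drained below once
+        # params.num_decode_streams of them are in flight.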
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + audio: np.ndarray = cut.load_audio() + if audio.max() > 1 or audio.min() < -1: + audio = audio / max(abs(audio.max()), abs(audio.min())) + print(audio) + print(audio.max()) + print(audio.min()) + print(cut) + idx += 1 + print(idx) + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
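+        # write_error_stats() below returns the WER for this decoding setting;
+        # the collected values are then sorted and written to the
+        # wer-summary-*.txt file further down.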
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not 
enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + commonvoice = CommonVoiceAsrDataModule(args) + test_cuts = commonvoice.test_cuts() + test_sets = "test-cv" + + results_dict = decode_dataset( + cuts=test_cuts, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_sets, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..5400df804 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert 
torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..a9bc9c2a2 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1256 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/fr/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
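The ``--average-period`` option and the ``update_averaged_model`` call in the loop above maintain a running average of all parameters seen so far. A minimal sketch of that update rule, using plain tensors and made-up values rather than the real ``icefall.checkpoint.update_averaged_model`` helper (which operates on whole checkpoints), is::

    import torch

    def toy_update_averaged_model(
        cur_param: torch.Tensor,
        avg_param: torch.Tensor,
        batch_idx_train: int,
        average_period: int,
    ) -> torch.Tensor:
        # Formula quoted in the --average-period help text:
        #   model_avg = model * (p / n) + model_avg * ((n - p) / n)
        # with p = average_period and n = batch_idx_train.
        weight = average_period / batch_idx_train
        return cur_param * weight + avg_param * (1.0 - weight)

    cur = torch.tensor([1.0, 2.0])  # current model parameter (toy values)
    avg = torch.tensor([0.5, 1.5])  # running average so far (toy values)
    print(toy_update_averaged_model(cur, avg, batch_idx_train=400, average_period=200))
    # tensor([0.7500, 1.7500])
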
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..c09c9537c --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1257 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
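``train2.py`` below is a near copy of ``train.py`` above; the visible differences are that it imports ``zipformer2`` and passes ``is_pnnx=True`` when building the encoder. One detail shared by both scripts that is easy to miss is the ``remove_short_and_long_utt`` filter: pruned RNN-T needs at least as many subsampled frames as BPE tokens (T >= S). A toy calculation of that check, with invented values and assuming the usual 10 ms frame shift, is::

    # Toy illustration of the T >= S check in remove_short_and_long_utt().
    # The numbers are made up; the formula matches the one in the script.
    num_frames = 873                      # e.g. ~8.7 s of audio at a 10 ms frame shift
    T = ((num_frames - 7) // 2 + 1) // 2  # frames left after subsampling
    num_tokens = 180                      # hypothetical BPE token count for the transcript

    keep = T >= num_tokens                # pruned RNN-T requires T >= S
    print(T, keep)                        # 217 True
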
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from commonvoice_fr import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
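As in ``train.py``, ``compute_loss`` above blends the simple and pruned losses with warm-up dependent weights. A standalone sketch of how those two scales evolve over the first ``warm_step`` batches (a toy loop with the default values, no model involved) is::

    # Toy sketch of the warm-up schedules used in compute_loss() above.
    warm_step = 2000  # default from get_params()
    s = 0.5           # default --simple-loss-scale

    for batch_idx_train in (0, 500, 1000, 2000, 5000):
        simple_scale = (
            s
            if batch_idx_train >= warm_step
            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        )
        pruned_scale = (
            1.0
            if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
        )
        # The simple loss starts dominant (1.0) and decays to s;
        # the pruned loss starts at 0.1 and ramps up to 1.0.
        print(batch_idx_train, round(simple_scale, 3), round(pruned_scale, 3))
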
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py deleted file mode 100755 index 29a2cd7f7..000000000 --- a/icefall/shared/convert-k2-to-openfst.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes as input an FST in k2 format and convert it -to an FST in OpenFST format. - -The generated FST is saved into a binary file and its type is -StdVectorFst. 
- -Usage examples: -(1) Convert an acceptor - - ./convert-k2-to-openfst.py in.pt binary.fst - -(2) Convert a transducer - - ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import kaldifst.utils -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--olabels", - type=str, - default=None, - help="""If not empty, the input FST is assumed to be a transducer - and we use its attribute specified by "olabels" as the output labels. - """, - ) - parser.add_argument( - "input_filename", - type=str, - help="Path to the input FST in k2 format", - ) - - parser.add_argument( - "output_filename", - type=str, - help="Path to the output FST in OpenFst format", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.info(f"{vars(args)}") - - input_filename = args.input_filename - output_filename = args.output_filename - olabels = args.olabels - - if Path(output_filename).is_file(): - logging.info(f"{output_filename} already exists - skipping") - return - - assert Path(input_filename).is_file(), f"{input_filename} does not exist" - logging.info(f"Loading {input_filename}") - k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) - if olabels: - assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" - - p = Path(output_filename).parent - if not p.is_dir(): - logging.info(f"Creating {p}") - p.mkdir(parents=True) - - logging.info("Converting (May take some time if the input FST is large)") - fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) - logging.info(f"Saving to {output_filename}") - fst.write(output_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 120000 index 000000000..24efe5eae --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/convert-k2-to-openfst.py \ No newline at end of file diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py deleted file mode 100755 index b1ebee9ea..000000000 --- a/icefall/shared/ngram_entropy_pruning.py +++ /dev/null @@ -1,630 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -./ngram_entropy_pruning.py \ - -threshold 1e-8 \ - -lm download/lm/4gram.arpa \ - -write-lm download/lm/4gram_pruned_1e8.arpa - -This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. -This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' -in the same way as SRILM. 
-""" - - -import argparse -import gzip -import logging -import math -import re -from collections import OrderedDict, defaultdict -from enum import Enum, unique -from io import StringIO - -parser = argparse.ArgumentParser( - description=""" - Prune an n-gram language model based on the relative entropy - between the original and the pruned model, based on Andreas Stolcke's paper. - An n-gram entry is removed, if the removal causes (training set) perplexity - of the model to increase by less than threshold relative. - - The command takes an arpa file and a pruning threshold as input, - and outputs a pruned arpa file. - """ -) -parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") -parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") -parser.add_argument( - "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" -) -parser.add_argument( - "-minorder", - type=int, - default=1, - help="The minorder parameter limits pruning to ngrams of that length and above.", -) -parser.add_argument( - "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" -) -parser.add_argument( - "-verbose", - type=int, - default=2, - choices=[0, 1, 2, 3, 4, 5], - help="Verbose level, where 0 is most noisy; 5 is most silent", -) -args = parser.parse_args() - -default_encoding = args.encoding -logging.basicConfig( - format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", - level=args.verbose * 10, -) - - -class Context(dict): - """ - This class stores data for a context h. - It behaves like a python dict object, except that it has several - additional attributes. - """ - - def __init__(self): - super().__init__() - self.log_bo = None - - -class Arpa: - """ - This is a class that implement the data structure of an APRA LM. 
- It (as well as some other classes) is modified based on the library - by Stefan Fischer: - https://github.com/sfischer13/python-arpa - """ - - UNK = "" - SOS = "" - EOS = "" - FLOAT_NDIGITS = 7 - base = 10 - - @staticmethod - def _check_input(my_input): - if not my_input: - raise ValueError - elif isinstance(my_input, tuple): - return my_input - elif isinstance(my_input, list): - return tuple(my_input) - elif isinstance(my_input, str): - return tuple(my_input.strip().split(" ")) - else: - raise ValueError - - @staticmethod - def _check_word(input_word): - if not isinstance(input_word, str): - raise ValueError - if " " in input_word: - raise ValueError - - def _replace_unks(self, words): - return tuple((w if w in self else self._unk) for w in words) - - def __init__(self, path=None, encoding=None, unk=None): - self._counts = OrderedDict() - self._ngrams = ( - OrderedDict() - ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) - self._vocabulary = set() - if unk is None: - self._unk = self.UNK - - if path is not None: - self.loadf(path, encoding) - - def __contains__(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] - - def contains_word(self, word): - self._check_word(word) - return word in self._vocabulary - - def add_count(self, order, count): - self._counts[order] = count - self._ngrams[order - 1] = defaultdict(Context) - - def update_counts(self): - for order in range(1, self.order() + 1): - count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) - if count > 0: - self._counts[order] = count - - def add_entry(self, ngram, p, bo=None, order=None): - # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - - # Note that p and bo here are in fact in the log domain (self.base = 10) - h_context = self._ngrams[len(h)][h] - h_context[w] = p - if bo is not None: - self._ngrams[len(ngram)][ngram].log_bo = bo - - for word in ngram: - self._vocabulary.add(word) - - def counts(self): - return sorted(self._counts.items()) - - def order(self): - return max(self._counts.keys(), default=None) - - def vocabulary(self, sort=True): - if sort: - return sorted(self._vocabulary) - else: - return self._vocabulary - - def _entries(self, order): - return ( - self._entry(h, w) - for h, wlist in self._ngrams[order - 1].items() - for w in wlist - ) - - def _entry(self, h, w): - # return the entry for the ngram (h, w) - ngram = h + (w,) - log_p = self._ngrams[len(h)][h][w] - log_bo = self._log_bo(ngram) - if log_bo is not None: - return ( - round(log_p, self.FLOAT_NDIGITS), - ngram, - round(log_bo, self.FLOAT_NDIGITS), - ) - else: - return round(log_p, self.FLOAT_NDIGITS), ngram - - def _log_bo(self, ngram): - if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: - return self._ngrams[len(ngram)][ngram].log_bo - else: - return None - - def _log_p(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: - return self._ngrams[len(h)][h][w] - else: - return None - - def log_p_raw(self, ngram): - log_p = self._log_p(ngram) - if log_p is not None: - return log_p - else: - if len(ngram) == 1: - raise KeyError - else: - log_bo = self._log_bo(ngram[:-1]) - if log_bo is None: - log_bo = 0 - return log_bo + self.log_p_raw(ngram[1:]) - - def log_joint_prob(self, sequence): - # Compute the joint prob of the sequence 
based on the chain rule - # Note that sequence should be a tuple of strings - # - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 - - log_joint_p = 0 - seq = sequence - while len(seq) > 0: - log_joint_p += self.log_p_raw(seq) - seq = seq[:-1] - - # If we're computing the marginal probability of the unigram - # context we have to look up instead since the former - # has prob = 0. - if len(seq) == 1 and seq[0] == self.SOS: - seq = (self.EOS,) - - return log_joint_p - - def set_new_context(self, h): - old_context = self._ngrams[len(h)][h] - self._ngrams[len(h)][h] = Context() - return old_context - - def log_p(self, ngram): - words = self._check_input(ngram) - if self._unk: - words = self._replace_unks(words) - return self.log_p_raw(words) - - def log_s(self, sentence, sos=SOS, eos=EOS): - words = self._check_input(sentence) - if self._unk: - words = self._replace_unks(words) - if sos: - words = (sos,) + words - if eos: - words = words + (eos,) - result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) - if sos: - result = result - self.log_p_raw(words[:1]) - return result - - def p(self, ngram): - return self.base ** self.log_p(ngram) - - def s(self, sentence): - return self.base ** self.log_s(sentence) - - def write(self, fp): - fp.write("\n\\data\\\n") - for order, count in self.counts(): - fp.write("ngram {}={}\n".format(order, count)) - fp.write("\n") - for order, _ in self.counts(): - fp.write("\\{}-grams:\n".format(order)) - for e in self._entries(order): - prob = e[0] - ngram = " ".join(e[1]) - if len(e) == 2: - fp.write("{}\t{}\n".format(prob, ngram)) - elif len(e) == 3: - backoff = e[2] - fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) - else: - raise ValueError - fp.write("\n") - fp.write("\\end\\\n") - - -class ArpaParser: - """ - This is a class that implement a parser of an arpa file - """ - - @unique - class State(Enum): - DATA = 1 - COUNT = 2 - HEADER = 3 - ENTRY = 4 - - re_count = re.compile(r"^ngram (\d+)=(\d+)$") - re_header = re.compile(r"^\\(\d+)-grams:$") - re_entry = re.compile( - "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" - "\t" - "(\\S+( \\S+)*)" - "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" - ) - - def _parse(self, fp): - self._result = [] - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - for line in fp: - line = line.strip() - if self._state == self.State.DATA: - self._data(line) - elif self._state == self.State.COUNT: - self._count(line) - elif self._state == self.State.HEADER: - self._header(line) - elif self._state == self.State.ENTRY: - self._entry(line) - if self._state != self.State.DATA: - raise Exception(line) - return self._result - - def _data(self, line): - if line == "\\data\\": - self._state = self.State.COUNT - self._tmp_model = Arpa() - else: - pass # skip comment line - - def _count(self, line): - match = self.re_count.match(line) - if match: - order = match.group(1) - count = match.group(2) - self._tmp_model.add_count(int(order), int(count)) - elif not line: - self._state = self.State.HEADER # there are no counts - else: - raise Exception(line) - - def _header(self, line): - match = self.re_header.match(line) - if match: - self._state = self.State.ENTRY - self._tmp_order = int(match.group(1)) - elif line == "\\end\\": - self._result.append(self._tmp_model) - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - elif not line: - pass # skip empty line - else: - raise Exception(line) - - def _entry(self, line): 
- match = self.re_entry.match(line) - if match: - p = self._float_or_int(match.group(1)) - ngram = tuple(match.group(4).split(" ")) - bo_match = match.group(7) - bo = self._float_or_int(bo_match) if bo_match else None - self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) - elif not line: - self._state = self.State.HEADER # last entry - else: - raise Exception(line) - - @staticmethod - def _float_or_int(s): - f = float(s) - i = int(f) - if str(i) == s: # don't drop trailing ".0" - return i - else: - return f - - def load(self, fp): - """Deserialize fp (a file-like object) to a Python object.""" - return self._parse(fp) - - def loadf(self, path, encoding=None): - """Deserialize path (.arpa, .gz) to a Python object.""" - path = str(path) - if path.endswith(".gz"): - with gzip.open(path, mode="rt", encoding=encoding) as f: - return self.load(f) - else: - with open(path, mode="rt", encoding=encoding) as f: - return self.load(f) - - def loads(self, s): - """Deserialize s (a str) to a Python object.""" - with StringIO(s) as f: - return self.load(f) - - def dump(self, obj, fp): - """Serialize obj to fp (a file-like object) in ARPA format.""" - obj.write(fp) - - def dumpf(self, obj, path, encoding=None): - """Serialize obj to path in ARPA format (.arpa, .gz).""" - path = str(path) - if path.endswith(".gz"): - with gzip.open(path, mode="wt", encoding=encoding) as f: - return self.dump(obj, f) - else: - with open(path, mode="wt", encoding=encoding) as f: - self.dump(obj, f) - - def dumps(self, obj): - """Serialize obj to an ARPA formatted str.""" - with StringIO() as f: - self.dump(obj, f) - return f.getvalue() - - -def add_log_p(prev_log_sum, log_p, base): - return math.log(base**log_p + base**prev_log_sum, base) - - -def compute_numerator_denominator(lm, h): - log_sum_seen_h = -math.inf - log_sum_seen_h_lower = -math.inf - base = lm.base - for w, log_p in lm._ngrams[len(h)][h].items(): - log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) - - ngram = h + (w,) - log_p_lower = lm.log_p_raw(ngram[1:]) - log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) - - numerator = 1.0 - base**log_sum_seen_h - denominator = 1.0 - base**log_sum_seen_h_lower - return numerator, denominator - - -def prune(lm, threshold, minorder): - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 - - for i in range( - lm.order(), max(minorder - 1, 1), -1 - ): # i is the order of the ngram (h, w) - logging.info("processing %d-grams ..." % i) - count_pruned_ngrams = 0 - - h_dict = lm._ngrams[i - 1] - for h in list(h_dict.keys()): - # old backoff weight, BOW(h) - log_bow = lm._log_bo(h) - if log_bow is None: - log_bow = 0 - - # Compute numerator and denominator of the backoff weight, - # so that we can quickly compute the BOW adjustment due to - # leaving out one prob. 
- numerator, denominator = compute_numerator_denominator(lm, h) - - # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 - - # Compute the marginal probability of the context, P(h) - h_log_p = lm.log_joint_prob(h) - - all_pruned = True - pruned_w_set = set() - - for w, log_p in h_dict[h].items(): - ngram = h + (w,) - - # lower-order estimate for ngramProb, P(w|h') - backoff_prob = lm.log_p_raw(ngram[1:]) - - # Compute BOW after removing ngram, BOW'(h) - new_log_bow = math.log( - numerator + lm.base**log_p, lm.base - ) - math.log(denominator + lm.base**backoff_prob, lm.base) - - # Compute change in entropy due to removal of ngram - delta_prob = backoff_prob + new_log_bow - log_p - delta_entropy = -(lm.base**h_log_p) * ( - (lm.base**log_p) * delta_prob - + numerator * (new_log_bow - log_bow) - ) - - # compute relative change in model (training set) perplexity - perp_change = lm.base**delta_entropy - 1.0 - - pruned = threshold > 0 and perp_change < threshold - - # Make sure we don't prune ngrams whose backoff nodes are needed - if ( - pruned - and len(ngram) in lm._ngrams - and len(lm._ngrams[len(ngram)][ngram]) > 0 - ): - pruned = False - - logging.debug( - "CONTEXT " - + str(h) - + " WORD " - + w - + " CONTEXTPROB %f " % h_log_p - + " OLDPROB %f " % log_p - + " NEWPROB %f " % (backoff_prob + new_log_bow) - + " DELTA-H %f " % delta_entropy - + " DELTA-LOGP %f " % delta_prob - + " PPL-CHANGE %f " % perp_change - + " PRUNED " - + str(pruned) - ) - - if pruned: - pruned_w_set.add(w) - count_pruned_ngrams += 1 - else: - all_pruned = False - - # If we removed all ngrams for this context we can - # remove the context itself, but only if the present - # context is not a prefix to a longer one. - if all_pruned and len(pruned_w_set) == len(h_dict[h]): - del h_dict[ - h - ] # this context h is no longer needed, as its ngram prob is stored at its own context h' - elif len(pruned_w_set) > 0: - # The pruning for this context h is actually done here - old_context = lm.set_new_context(h) - - for w, p_w in old_context.items(): - if w not in pruned_w_set: - lm.add_entry( - h + (w,), p_w - ) # the entry hw is stored at the context h - - # We need to recompute the back-off weight, but - # this can only be done after completing the pruning - # of the lower-order ngrams. - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 - - logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) - - # recompute backoff weights - for i in range( - max(minorder - 1, 1) + 1, lm.order() + 1 - ): # be careful of this order: from low- to high-order - for h in lm._ngrams[i - 1]: - numerator, denominator = compute_numerator_denominator(lm, h) - new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) - lm._ngrams[len(h)][h].log_bo = new_log_bow - - # update counts - lm.update_counts() - - return - - -def check_h_is_valid(lm, h): - sum_under_h = sum( - [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] - ) - if abs(sum_under_h - 1.0) > 1e-6: - logging.info("warning: %s %f" % (str(h), sum_under_h)) - return False - else: - return True - - -def validate_lm(lm): - # sanity check if the conditional probability sums to one under each context h - for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) - logging.info("validating %d-grams ..." 
% i) - h_dict = lm._ngrams[i - 1] - for h in h_dict.keys(): - check_h_is_valid(lm, h) - - -def compare_two_apras(path1, path2): - pass - - -if __name__ == "__main__": - # load an arpa file - logging.info("Loading the arpa file from %s" % args.lm) - parser = ArpaParser() - models = parser.loadf(args.lm, encoding=default_encoding) - lm = models[0] # ARPA files may contain several models. - logging.info("Stats before pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - - # prune it, the language model will be modified in-place - logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) - prune(lm, args.threshold, args.minorder) - - # validate_lm(lm) - - # write the arpa language model to a file - logging.info("Stats after pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - logging.info("Saving the pruned arpa file to %s" % args.write_lm) - parser.dumpf(lm, args.write_lm, encoding=default_encoding) - logging.info("Done.") diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 120000 index 000000000..0e14ac415 --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/ngram_entropy_pruning.py \ No newline at end of file diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh deleted file mode 100755 index 71fb9e5ea..000000000 --- a/icefall/shared/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. - fi -done - - -### -### Now we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh new file mode 120000 index 000000000..e4665e7de --- /dev/null +++ b/icefall/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file From 6d275ddf9fdd67a32b79b93d70fedffe4b156d5c Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 10 Nov 2023 14:45:16 +0800 Subject: [PATCH 100/113] fixed broken softlinks (#1381) * removed broken softlinks * fixed dependencies * fixed file permission --- icefall/shared/convert-k2-to-openfst.py | 103 +++- icefall/shared/ngram_entropy_pruning.py | 631 +++++++++++++++++++++++- icefall/shared/parse_options.sh | 98 +++- 3 files changed, 829 insertions(+), 3 deletions(-) mode change 120000 => 100755 icefall/shared/convert-k2-to-openfst.py mode change 120000 => 100755 icefall/shared/ngram_entropy_pruning.py mode change 120000 => 100755 icefall/shared/parse_options.sh diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py deleted file mode 120000 index 24efe5eae..000000000 --- a/icefall/shared/convert-k2-to-openfst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/convert-k2-to-openfst.py \ No newline at end of file diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 100755 index 000000000..29a2cd7f7 --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input an FST in k2 format and convert it +to an FST in OpenFST format. + +The generated FST is saved into a binary file and its type is +StdVectorFst. + +Usage examples: +(1) Convert an acceptor + + ./convert-k2-to-openfst.py in.pt binary.fst + +(2) Convert a transducer + + ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst.utils +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--olabels", + type=str, + default=None, + help="""If not empty, the input FST is assumed to be a transducer + and we use its attribute specified by "olabels" as the output labels. + """, + ) + parser.add_argument( + "input_filename", + type=str, + help="Path to the input FST in k2 format", + ) + + parser.add_argument( + "output_filename", + type=str, + help="Path to the output FST in OpenFst format", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(f"{vars(args)}") + + input_filename = args.input_filename + output_filename = args.output_filename + olabels = args.olabels + + if Path(output_filename).is_file(): + logging.info(f"{output_filename} already exists - skipping") + return + + assert Path(input_filename).is_file(), f"{input_filename} does not exist" + logging.info(f"Loading {input_filename}") + k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) + if olabels: + assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" + + p = Path(output_filename).parent + if not p.is_dir(): + logging.info(f"Creating {p}") + p.mkdir(parents=True) + + logging.info("Converting (May take some time if the input FST is large)") + fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) + logging.info(f"Saving to {output_filename}") + fst.write(output_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py deleted file mode 120000 index 0e14ac415..000000000 --- a/icefall/shared/ngram_entropy_pruning.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/ngram_entropy_pruning.py \ No newline at end of file diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 100755 index 000000000..b1ebee9ea --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./ngram_entropy_pruning.py \ + -threshold 1e-8 \ + -lm download/lm/4gram.arpa \ + -write-lm download/lm/4gram_pruned_1e8.arpa + +This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. +This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' +in the same way as SRILM. +""" + + +import argparse +import gzip +import logging +import math +import re +from collections import OrderedDict, defaultdict +from enum import Enum, unique +from io import StringIO + +parser = argparse.ArgumentParser( + description=""" + Prune an n-gram language model based on the relative entropy + between the original and the pruned model, based on Andreas Stolcke's paper. + An n-gram entry is removed, if the removal causes (training set) perplexity + of the model to increase by less than threshold relative. + + The command takes an arpa file and a pruning threshold as input, + and outputs a pruned arpa file. + """ +) +parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") +parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") +parser.add_argument( + "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" +) +parser.add_argument( + "-minorder", + type=int, + default=1, + help="The minorder parameter limits pruning to ngrams of that length and above.", +) +parser.add_argument( + "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" +) +parser.add_argument( + "-verbose", + type=int, + default=2, + choices=[0, 1, 2, 3, 4, 5], + help="Verbose level, where 0 is most noisy; 5 is most silent", +) +args = parser.parse_args() + +default_encoding = args.encoding +logging.basicConfig( + format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", + level=args.verbose * 10, +) + + +class Context(dict): + """ + This class stores data for a context h. + It behaves like a python dict object, except that it has several + additional attributes. + """ + + def __init__(self): + super().__init__() + self.log_bo = None + + +class Arpa: + """ + This is a class that implement the data structure of an APRA LM. 
+ It (as well as some other classes) is modified based on the library + by Stefan Fischer: + https://github.com/sfischer13/python-arpa + """ + + UNK = "" + SOS = "" + EOS = "" + FLOAT_NDIGITS = 7 + base = 10 + + @staticmethod + def _check_input(my_input): + if not my_input: + raise ValueError + elif isinstance(my_input, tuple): + return my_input + elif isinstance(my_input, list): + return tuple(my_input) + elif isinstance(my_input, str): + return tuple(my_input.strip().split(" ")) + else: + raise ValueError + + @staticmethod + def _check_word(input_word): + if not isinstance(input_word, str): + raise ValueError + if " " in input_word: + raise ValueError + + def _replace_unks(self, words): + return tuple((w if w in self else self._unk) for w in words) + + def __init__(self, path=None, encoding=None, unk=None): + self._counts = OrderedDict() + self._ngrams = ( + OrderedDict() + ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) + self._vocabulary = set() + if unk is None: + self._unk = self.UNK + + if path is not None: + self.loadf(path, encoding) + + def __contains__(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] + + def contains_word(self, word): + self._check_word(word) + return word in self._vocabulary + + def add_count(self, order, count): + self._counts[order] = count + self._ngrams[order - 1] = defaultdict(Context) + + def update_counts(self): + for order in range(1, self.order() + 1): + count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) + if count > 0: + self._counts[order] = count + + def add_entry(self, ngram, p, bo=None, order=None): + # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + + # Note that p and bo here are in fact in the log domain (self.base = 10) + h_context = self._ngrams[len(h)][h] + h_context[w] = p + if bo is not None: + self._ngrams[len(ngram)][ngram].log_bo = bo + + for word in ngram: + self._vocabulary.add(word) + + def counts(self): + return sorted(self._counts.items()) + + def order(self): + return max(self._counts.keys(), default=None) + + def vocabulary(self, sort=True): + if sort: + return sorted(self._vocabulary) + else: + return self._vocabulary + + def _entries(self, order): + return ( + self._entry(h, w) + for h, wlist in self._ngrams[order - 1].items() + for w in wlist + ) + + def _entry(self, h, w): + # return the entry for the ngram (h, w) + ngram = h + (w,) + log_p = self._ngrams[len(h)][h][w] + log_bo = self._log_bo(ngram) + if log_bo is not None: + return ( + round(log_p, self.FLOAT_NDIGITS), + ngram, + round(log_bo, self.FLOAT_NDIGITS), + ) + else: + return round(log_p, self.FLOAT_NDIGITS), ngram + + def _log_bo(self, ngram): + if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: + return self._ngrams[len(ngram)][ngram].log_bo + else: + return None + + def _log_p(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: + return self._ngrams[len(h)][h][w] + else: + return None + + def log_p_raw(self, ngram): + log_p = self._log_p(ngram) + if log_p is not None: + return log_p + else: + if len(ngram) == 1: + raise KeyError + else: + log_bo = self._log_bo(ngram[:-1]) + if log_bo is None: + log_bo = 0 + return log_bo + self.log_p_raw(ngram[1:]) + + def log_joint_prob(self, sequence): + # Compute the joint prob of the sequence 
based on the chain rule + # Note that sequence should be a tuple of strings + # + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 + + log_joint_p = 0 + seq = sequence + while len(seq) > 0: + log_joint_p += self.log_p_raw(seq) + seq = seq[:-1] + + # If we're computing the marginal probability of the unigram + # context we have to look up instead since the former + # has prob = 0. + if len(seq) == 1 and seq[0] == self.SOS: + seq = (self.EOS,) + + return log_joint_p + + def set_new_context(self, h): + old_context = self._ngrams[len(h)][h] + self._ngrams[len(h)][h] = Context() + return old_context + + def log_p(self, ngram): + words = self._check_input(ngram) + if self._unk: + words = self._replace_unks(words) + return self.log_p_raw(words) + + def log_s(self, sentence, sos=SOS, eos=EOS): + words = self._check_input(sentence) + if self._unk: + words = self._replace_unks(words) + if sos: + words = (sos,) + words + if eos: + words = words + (eos,) + result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) + if sos: + result = result - self.log_p_raw(words[:1]) + return result + + def p(self, ngram): + return self.base ** self.log_p(ngram) + + def s(self, sentence): + return self.base ** self.log_s(sentence) + + def write(self, fp): + fp.write("\n\\data\\\n") + for order, count in self.counts(): + fp.write("ngram {}={}\n".format(order, count)) + fp.write("\n") + for order, _ in self.counts(): + fp.write("\\{}-grams:\n".format(order)) + for e in self._entries(order): + prob = e[0] + ngram = " ".join(e[1]) + if len(e) == 2: + fp.write("{}\t{}\n".format(prob, ngram)) + elif len(e) == 3: + backoff = e[2] + fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) + else: + raise ValueError + fp.write("\n") + fp.write("\\end\\\n") + + +class ArpaParser: + """ + This is a class that implement a parser of an arpa file + """ + + @unique + class State(Enum): + DATA = 1 + COUNT = 2 + HEADER = 3 + ENTRY = 4 + + re_count = re.compile(r"^ngram (\d+)=(\d+)$") + re_header = re.compile(r"^\\(\d+)-grams:$") + re_entry = re.compile( + "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" + "\t" + "(\\S+( \\S+)*)" + "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" + ) + + def _parse(self, fp): + self._result = [] + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + for line in fp: + line = line.strip() + if self._state == self.State.DATA: + self._data(line) + elif self._state == self.State.COUNT: + self._count(line) + elif self._state == self.State.HEADER: + self._header(line) + elif self._state == self.State.ENTRY: + self._entry(line) + if self._state != self.State.DATA: + raise Exception(line) + return self._result + + def _data(self, line): + if line == "\\data\\": + self._state = self.State.COUNT + self._tmp_model = Arpa() + else: + pass # skip comment line + + def _count(self, line): + match = self.re_count.match(line) + if match: + order = match.group(1) + count = match.group(2) + self._tmp_model.add_count(int(order), int(count)) + elif not line: + self._state = self.State.HEADER # there are no counts + else: + raise Exception(line) + + def _header(self, line): + match = self.re_header.match(line) + if match: + self._state = self.State.ENTRY + self._tmp_order = int(match.group(1)) + elif line == "\\end\\": + self._result.append(self._tmp_model) + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + elif not line: + pass # skip empty line + else: + raise Exception(line) + + def _entry(self, line): 
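+        # Parse one entry line of the form
+        #   <log10 prob><TAB><ngram words>[<TAB><log10 backoff>]
+        # and add it to the model currently being built.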
+ match = self.re_entry.match(line) + if match: + p = self._float_or_int(match.group(1)) + ngram = tuple(match.group(4).split(" ")) + bo_match = match.group(7) + bo = self._float_or_int(bo_match) if bo_match else None + self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) + elif not line: + self._state = self.State.HEADER # last entry + else: + raise Exception(line) + + @staticmethod + def _float_or_int(s): + f = float(s) + i = int(f) + if str(i) == s: # don't drop trailing ".0" + return i + else: + return f + + def load(self, fp): + """Deserialize fp (a file-like object) to a Python object.""" + return self._parse(fp) + + def loadf(self, path, encoding=None): + """Deserialize path (.arpa, .gz) to a Python object.""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + else: + with open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + + def loads(self, s): + """Deserialize s (a str) to a Python object.""" + with StringIO(s) as f: + return self.load(f) + + def dump(self, obj, fp): + """Serialize obj to fp (a file-like object) in ARPA format.""" + obj.write(fp) + + def dumpf(self, obj, path, encoding=None): + """Serialize obj to path in ARPA format (.arpa, .gz).""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="wt", encoding=encoding) as f: + return self.dump(obj, f) + else: + with open(path, mode="wt", encoding=encoding) as f: + self.dump(obj, f) + + def dumps(self, obj): + """Serialize obj to an ARPA formatted str.""" + with StringIO() as f: + self.dump(obj, f) + return f.getvalue() + + +def add_log_p(prev_log_sum, log_p, base): + return math.log(base**log_p + base**prev_log_sum, base) + + +def compute_numerator_denominator(lm, h): + log_sum_seen_h = -math.inf + log_sum_seen_h_lower = -math.inf + base = lm.base + for w, log_p in lm._ngrams[len(h)][h].items(): + log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) + + ngram = h + (w,) + log_p_lower = lm.log_p_raw(ngram[1:]) + log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) + + numerator = 1.0 - base**log_sum_seen_h + denominator = 1.0 - base**log_sum_seen_h_lower + return numerator, denominator + + +def prune(lm, threshold, minorder): + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 + + for i in range( + lm.order(), max(minorder - 1, 1), -1 + ): # i is the order of the ngram (h, w) + logging.info("processing %d-grams ..." % i) + count_pruned_ngrams = 0 + + h_dict = lm._ngrams[i - 1] + for h in list(h_dict.keys()): + # old backoff weight, BOW(h) + log_bow = lm._log_bo(h) + if log_bow is None: + log_bow = 0 + + # Compute numerator and denominator of the backoff weight, + # so that we can quickly compute the BOW adjustment due to + # leaving out one prob. 
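+            # compute_numerator_denominator(lm, h) returns
+            #   numerator   = 1 - sum of P(w|h)  over all w seen after h
+            #   denominator = 1 - sum of P(w|h') over the same w,
+            # where h' drops the first word of h.  Since BOW(h) equals
+            # numerator / denominator, the loop below can recompute the
+            # adjusted BOW'(h) after dropping one n-gram simply by adding
+            # P(w|h) back to the numerator and P(w|h') to the denominator.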
+ numerator, denominator = compute_numerator_denominator(lm, h) + + # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 + + # Compute the marginal probability of the context, P(h) + h_log_p = lm.log_joint_prob(h) + + all_pruned = True + pruned_w_set = set() + + for w, log_p in h_dict[h].items(): + ngram = h + (w,) + + # lower-order estimate for ngramProb, P(w|h') + backoff_prob = lm.log_p_raw(ngram[1:]) + + # Compute BOW after removing ngram, BOW'(h) + new_log_bow = math.log( + numerator + lm.base**log_p, lm.base + ) - math.log(denominator + lm.base**backoff_prob, lm.base) + + # Compute change in entropy due to removal of ngram + delta_prob = backoff_prob + new_log_bow - log_p + delta_entropy = -(lm.base**h_log_p) * ( + (lm.base**log_p) * delta_prob + + numerator * (new_log_bow - log_bow) + ) + + # compute relative change in model (training set) perplexity + perp_change = lm.base**delta_entropy - 1.0 + + pruned = threshold > 0 and perp_change < threshold + + # Make sure we don't prune ngrams whose backoff nodes are needed + if ( + pruned + and len(ngram) in lm._ngrams + and len(lm._ngrams[len(ngram)][ngram]) > 0 + ): + pruned = False + + logging.debug( + "CONTEXT " + + str(h) + + " WORD " + + w + + " CONTEXTPROB %f " % h_log_p + + " OLDPROB %f " % log_p + + " NEWPROB %f " % (backoff_prob + new_log_bow) + + " DELTA-H %f " % delta_entropy + + " DELTA-LOGP %f " % delta_prob + + " PPL-CHANGE %f " % perp_change + + " PRUNED " + + str(pruned) + ) + + if pruned: + pruned_w_set.add(w) + count_pruned_ngrams += 1 + else: + all_pruned = False + + # If we removed all ngrams for this context we can + # remove the context itself, but only if the present + # context is not a prefix to a longer one. + if all_pruned and len(pruned_w_set) == len(h_dict[h]): + del h_dict[ + h + ] # this context h is no longer needed, as its ngram prob is stored at its own context h' + elif len(pruned_w_set) > 0: + # The pruning for this context h is actually done here + old_context = lm.set_new_context(h) + + for w, p_w in old_context.items(): + if w not in pruned_w_set: + lm.add_entry( + h + (w,), p_w + ) # the entry hw is stored at the context h + + # We need to recompute the back-off weight, but + # this can only be done after completing the pruning + # of the lower-order ngrams. + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 + + logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) + + # recompute backoff weights + for i in range( + max(minorder - 1, 1) + 1, lm.order() + 1 + ): # be careful of this order: from low- to high-order + for h in lm._ngrams[i - 1]: + numerator, denominator = compute_numerator_denominator(lm, h) + new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) + lm._ngrams[len(h)][h].log_bo = new_log_bow + + # update counts + lm.update_counts() + + return + + +def check_h_is_valid(lm, h): + sum_under_h = sum( + [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] + ) + if abs(sum_under_h - 1.0) > 1e-6: + logging.info("warning: %s %f" % (str(h), sum_under_h)) + return False + else: + return True + + +def validate_lm(lm): + # sanity check if the conditional probability sums to one under each context h + for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) + logging.info("validating %d-grams ..." 
% i) + h_dict = lm._ngrams[i - 1] + for h in h_dict.keys(): + check_h_is_valid(lm, h) + + +def compare_two_apras(path1, path2): + pass + + +if __name__ == "__main__": + # load an arpa file + logging.info("Loading the arpa file from %s" % args.lm) + parser = ArpaParser() + models = parser.loadf(args.lm, encoding=default_encoding) + lm = models[0] # ARPA files may contain several models. + logging.info("Stats before pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + + # prune it, the language model will be modified in-place + logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) + prune(lm, args.threshold, args.minorder) + + # validate_lm(lm) + + # write the arpa language model to a file + logging.info("Stats after pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + logging.info("Saving the pruned arpa file to %s" % args.write_lm) + parser.dumpf(lm, args.write_lm, encoding=default_encoding) + logging.info("Done.") diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh deleted file mode 120000 index e4665e7de..000000000 --- a/icefall/shared/parse_options.sh +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh new file mode 100755 index 000000000..71fb9e5ea --- /dev/null +++ b/icefall/shared/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. From 59c943878ff7f3d741a29d743b8560e342fa892d Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Thu, 16 Nov 2023 07:38:31 +0100 Subject: [PATCH 101/113] add the `voxpopuli` recipe (#1374) * add the `voxpopuli` recipe - this is the data preparation - there is no ASR training and no results * update the PR#1374 (feedback from @csukuangfj) - fixing .py headers and docstrings - removing BUT specific parts of `prepare.sh` - adding assert `num_jobs >= num_workers` to `compute_fbank.py` - narrowing list of languages (let's limit to ASR sets with transcripts for now) - added links to `README.md` - extending `text_from_manifest.py` --- egs/voxpopuli/ASR/README.md | 38 +++ egs/voxpopuli/ASR/local/compute_fbank.py | 248 +++++++++++++++++ .../ASR/local/compute_fbank_musan.py | 1 + .../ASR/local/display_manifest_statistics.py | 56 ++++ .../duration_from_supervision_manifest.py | 93 +++++++ egs/voxpopuli/ASR/local/filter_cuts.py | 1 + egs/voxpopuli/ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/preprocess_voxpopuli.py | 178 ++++++++++++ .../ASR/local/separate_punctuation.py | 130 +++++++++ egs/voxpopuli/ASR/local/text_from_manifest.py | 54 ++++ egs/voxpopuli/ASR/local/train_bpe_model.py | 1 + .../ASR/local/uppercase_begin_of_sentence.py | 113 ++++++++ .../ASR/local/validate_bpe_lexicon.py | 1 + .../ASR/local/validate_cutset_manifest.py | 123 +++++++++ egs/voxpopuli/ASR/prepare.sh | 257 ++++++++++++++++++ egs/voxpopuli/ASR/shared | 1 + 16 files changed, 1296 insertions(+) create mode 100644 egs/voxpopuli/ASR/README.md create mode 100755 egs/voxpopuli/ASR/local/compute_fbank.py create mode 120000 egs/voxpopuli/ASR/local/compute_fbank_musan.py create mode 100755 egs/voxpopuli/ASR/local/display_manifest_statistics.py create mode 100755 egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py create mode 120000 egs/voxpopuli/ASR/local/filter_cuts.py create mode 
120000 egs/voxpopuli/ASR/local/prepare_lang_bpe.py create mode 100755 egs/voxpopuli/ASR/local/preprocess_voxpopuli.py create mode 100755 egs/voxpopuli/ASR/local/separate_punctuation.py create mode 100755 egs/voxpopuli/ASR/local/text_from_manifest.py create mode 120000 egs/voxpopuli/ASR/local/train_bpe_model.py create mode 100755 egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py create mode 120000 egs/voxpopuli/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/voxpopuli/ASR/local/validate_cutset_manifest.py create mode 100755 egs/voxpopuli/ASR/prepare.sh create mode 120000 egs/voxpopuli/ASR/shared diff --git a/egs/voxpopuli/ASR/README.md b/egs/voxpopuli/ASR/README.md new file mode 100644 index 000000000..92aa26464 --- /dev/null +++ b/egs/voxpopuli/ASR/README.md @@ -0,0 +1,38 @@ +# Readme + +This recipe contains data preparation for the +[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset +[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf). +At the moment, without model training. + + +## audio per language + +| language | Size | Hrs. untranscribed | Hrs. transcribed | +|----------|--------|--------------------|------------------| +| bg | 295G | 17.6K | - | +| cs | 308G | 18.7K | 62 | +| da | 233G | 13.6K | - | +| de | 379G | 23.2K | 282 | +| el | 305G | 17.7K | - | +| en | 382G | 24.1K | 543 | +| es | 362G | 21.4K | 166 | +| et | 179G | 10.6K | 3 | +| fi | 236G | 14.2K | 27 | +| fr | 376G | 22.8K | 211 | +| hr | 132G | 8.1K | 43 | +| hu | 297G | 17.7K | 63 | +| it | 361G | 21.9K | 91 | +| lt | 243G | 14.4K | 2 | +| lv | 217G | 13.1K | - | +| mt | 147G | 9.1K | - | +| nl | 322G | 19.0K | 53 | +| pl | 348G | 21.2K | 111 | +| pt | 300G | 17.5K | - | +| ro | 296G | 17.9K | 89 | +| sk | 201G | 12.1K | 35 | +| sl | 190G | 11.3K | 10 | +| sv | 272G | 16.3K | - | +| | | | | +| total | 6.3T | 384K | 1791 | + diff --git a/egs/voxpopuli/ASR/local/compute_fbank.py b/egs/voxpopuli/ASR/local/compute_fbank.py new file mode 100755 index 000000000..b63e51f29 --- /dev/null +++ b/egs/voxpopuli/ASR/local/compute_fbank.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of VoxPopuli dataset. + +Usage example: + + python3 ./local/compute_fbank.py \ + --src-dir data/fbank --output-dir data/fbank \ + --num-jobs 100 --num-workers 25 \ + --prefix "voxpopuli-${task}-${lang}" \ + --dataset train \ + --trim-to-supervisions True \ + --speed-perturb True + +It looks for raw CutSet in the directory data/fbank +located at: `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`. + +The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats` +and CutSet manifest stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`. 
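+
+The generated cuts can be checked with ./local/display_manifest_statistics.py,
+e.g. (assuming task=asr and lang=en were used, i.e. prefix "voxpopuli-asr-en"):
+
+    python3 ./local/display_manifest_statistics.py \
+        data/fbank/voxpopuli-asr-en_cuts_train.jsonl.gz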
+ +Typically, the number of workers is smaller than number of jobs +(see --num-jobs 100 --num-workers 25 in the example). +And, the number of jobs should be at least the number of workers (it's checked). +""" + +import argparse +import logging +import multiprocessing +import os +from concurrent.futures import ProcessPoolExecutor +from pathlib import Path + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + is_caching_enabled, + set_caching_enabled, +) + +from icefall.utils import str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + parser.add_argument( + "--src-dir", + type=str, + help="""Folder with the input manifest files.""", + default="data/manifests", + ) + parser.add_argument( + "--output-dir", + type=str, + help="""Folder with the output manifests (cuts) and feature files.""", + default="data/fbank", + ) + + parser.add_argument( + "--prefix", + type=str, + help="""Prefix of the manifest files.""", + default="", + ) + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank (train,test,dev).""", + default=None, + ) + + parser.add_argument( + "--num-jobs", + type=int, + help="""Number of jobs (i.e. files with extracted features)""", + default=50, + ) + parser.add_argument( + "--num-workers", + type=int, + help="""Number of parallel workers""", + default=10, + ) + parser.add_argument( + "--speed-perturb", + type=str2bool, + default=False, + help="""Enable speed perturbation for the set.""", + ) + parser.add_argument( + "--trim-to-supervisions", + type=str2bool, + default=False, + help="""Apply `trim-to-supervision` to cut set.""", + ) + + return parser.parse_args() + + +def compute_fbank_features(args: argparse.Namespace): + set_caching_enabled(True) # lhotse + + src_dir = Path(args.src_dir) + output_dir = Path(args.output_dir) + num_jobs = args.num_jobs + num_workers = min(args.num_workers, os.cpu_count()) + num_mel_bins = 80 + + bpe_model = args.bpe_model + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + prefix = args.prefix # "ELEF_TRAIN" + dataset = args.dataset + suffix = "jsonl.gz" + + cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}") + cuts_raw = CutSet.from_file(cuts_raw_filename) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}") + if (output_dir / cuts_filename).is_file(): + logging.info(f"{output_dir/cuts_filename} already exists - skipping.") + return + + logging.info(f"Processing {output_dir/cuts_filename}") + cut_set = cuts_raw + + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + + if args.speed_perturb: + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + + if args.trim_to_supervisions: + logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}") + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + else: + 
logging.info( + "Not doing `trim_to_supervisions()`, " + "to enable use --trim-to-supervision=True" + ) + + cut_set = cut_set.to_eager() # disallow lazy evaluation (sorting requires it) + cut_set = cut_set.sort_by_recording_id() # enhances AudioCache hit rate + + # We typically use `num_jobs=100, num_workers=20` + # - this is helpful for large databases + # - both values are configurable externally + assert num_jobs >= num_workers, (num_jobs, num_workers) + executor = ProcessPoolExecutor( + max_workers=num_workers, + mp_context=multiprocessing.get_context("spawn"), + initializer=set_caching_enabled, + initargs=(is_caching_enabled(),), + ) + + logging.info( + f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}" + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir / prefix}-{dataset}_feats", + num_jobs=num_jobs, + executor=executor, + storage_type=LilcomChunkyWriter, + ) + + # correct small deviations of duration, caused by speed-perturbation + for cut in cut_set: + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id) + duration_difference = abs(cut.supervisions[0].duration - cut.duration) + tolerance = 0.02 # 20ms + if duration_difference == 0.0: + pass + elif duration_difference <= tolerance: + logging.info( + "small mismatch of the supervision duration " + f"(Δt = {duration_difference*1000}ms), " + f"correcting : cut.duration {cut.duration} -> " + f"supervision {cut.supervisions[0].duration}" + ) + cut.supervisions[0].duration = cut.duration + else: + logging.error( + "mismatch of cut/supervision duration " + f"(Δt = {duration_difference*1000}ms) : " + f"cut.duration {cut.duration}, " + f"supervision {cut.supervisions[0].duration}" + ) + raise ValueError( + "mismatch of cut/supervision duration " + f"(Δt = {duration_difference*1000}ms)" + ) + + # store the cutset + logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`") + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + logging.info(vars(args)) + + compute_fbank_features(args) diff --git a/egs/voxpopuli/ASR/local/compute_fbank_musan.py b/egs/voxpopuli/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/voxpopuli/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/display_manifest_statistics.py b/egs/voxpopuli/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..36c99e126 --- /dev/null +++ b/egs/voxpopuli/ASR/local/display_manifest_statistics.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +Usage example: + python3 ./local/display_manifest_statistics.py data/fbank/*_cuts*.jsonl.gz + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. + +""" + +import argparse + +from lhotse import load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz") + + parser.add_argument( + "filename", + help="data/fbank/imported_cuts_bison-train_trim.jsonl.gz", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + cuts = load_manifest_lazy(args.filename) + cuts.describe() + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py new file mode 100755 index 000000000..957267fe8 --- /dev/null +++ b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script computes durations of datasets from +the SupervisionSet manifests. + +Usage example: + + python3 ./local/duration_from_supervision_manifest.py \ + data/manifest/*_superivions*.jsonl.gz +""" + +import argparse +import gzip +import json +import logging +import re +import sys + + +def get_args(): + parser = argparse.ArgumentParser( + "Read the raw text from the 'supervisions.jsonl.gz'" + ) + + parser.add_argument( + "filename", + help="supervisions.jsonl.gz", + nargs="+", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(vars(args)) + + total_duration = 0.0 + total_n_utts = 0 + + for fname in args.filename: + if fname == "-": + fd = sys.stdin + elif re.match(r".*\.jsonl\.gz$", fname): + fd = gzip.open(fname, mode="r") + else: + fd = open(fname, mode="r") + + fname_duration = 0.0 + n_utts = 0 + for line in fd: + js = json.loads(line) + fname_duration += js["duration"] + n_utts += 1 + + print( + f"Duration: {fname_duration/3600:7.2f} hours " + f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}" + ) + + if fd != sys.stdin: + fd.close() + + total_duration += fname_duration + total_n_utts += n_utts + + print( + f"Total duration: {total_duration/3600:7.2f} hours " + f"(eq. 
{total_duration:7.0f} seconds)" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/voxpopuli/ASR/local/filter_cuts.py b/egs/voxpopuli/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/voxpopuli/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py new file mode 100755 index 000000000..4032537db --- /dev/null +++ b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# 2023 Brno University of Technology (author: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preprocess the database. +- Convert RecordingSet and SupervisionSet to CutSet. +- Apply text normalization to the transcripts. + - We take renormalized `orig_text` as `text` transcripts. + - The text normalization is separating punctuation from words. + - Also we put capital letter to the beginning of a sentence. + +The script is inspired in: + `egs/commonvoice/ASR/local/preprocess_commonvoice.py` + +Usage example: + python3 ./local/preprocess_voxpopuli.py \ + --task asr --lang en + +""" + +import argparse +import logging +from pathlib import Path +from typing import Optional + +from lhotse import CutSet +from lhotse.recipes.utils import read_manifests_if_cached + +# from local/ +from separate_punctuation import separate_punctuation +from uppercase_begin_of_sentence import UpperCaseBeginOfSentence + +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + default=None, + ) + + parser.add_argument( + "--task", + type=str, + help="""Task of VoxPopuli""", + default="asr", + ) + + parser.add_argument( + "--lang", + type=str, + help="""Language of VoxPopuli""", + required=True, + ) + + parser.add_argument( + "--use-original-text", + type=str2bool, + help="""Use 'original_text' from the annoattaion file, + otherwise 'normed_text' will be used + (see `data/manifests/${task}_${lang}.tsv.gz`). 
+ """, + default=False, + ) + + return parser.parse_args() + + +def normalize_text(utt: str) -> str: + utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt)) + return utt + + +def preprocess_voxpopuli( + task: str, + language: str, + dataset: Optional[str] = None, + use_original_text: bool = False, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + output_dir.mkdir(exist_ok=True) + + if dataset is None: + dataset_parts = ( + "dev", + "test", + "train", + ) + else: + dataset_parts = dataset.split(" ", -1) + + logging.info("Loading manifest") + prefix = f"voxpopuli-{task}-{language}" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + prefix=prefix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + if use_original_text: + logging.info("Using 'original_text' from the annotation file.") + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + # `orig_text` includes punctuation and true-case + orig_text = str(sup.custom["orig_text"]) + # we replace `text` by normalized `orig_text` + sup.text = normalize_text(orig_text) + else: + logging.info("Using 'normed_text' from the annotation file.") + + # remove supervisions with empty 'text' + m["supervisions"] = m["supervisions"].filter(lambda sup: len(sup.text) > 0) + + # Create cut manifest with long-recordings. + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ).resample(16000) + + # Store the cut set incl. the resampling. + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + preprocess_voxpopuli( + task=args.task, + language=args.lang, + dataset=args.dataset, + use_original_text=args.use_original_text, + ) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/separate_punctuation.py b/egs/voxpopuli/ASR/local/separate_punctuation.py new file mode 100755 index 000000000..706d6fcd5 --- /dev/null +++ b/egs/voxpopuli/ASR/local/separate_punctuation.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script chops the punctuation as standalone tokens. +Example: + input: "This is fine. Yes, you are right." 
+ output: "This is fine . Yes , you are right ." + +The script also handles exceptions in a hard-coded fashion. + +(same functionality could be done with `nltk.tokenize.word_tokenize()`, + but that would be an extra dependency) + +It can be used as a module, or as an executable script. + +Usage example #1: + `from separate_punctuation import separate_punctuation` + +Usage example #2: +``` + python3 ./local/separate_punctuation.py \ + --ignore-columns 1 \ + < ${kaldi_data}/text +``` +""" + +import re +import sys +from argparse import ArgumentParser + + +def separate_punctuation(text: str) -> str: + """ + Text filtering function for separating punctuation. + + Example: + input: "This is fine. Yes, you are right." + output: "This is fine . Yes , you are right ." + + The exceptions for which the punctuation is + not splitted are hard-coded. + """ + + # remove non-desired punctuation symbols + text = re.sub('["„“«»]', "", text) + + # separate [,.!?;] punctuation from words by space + text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text) + text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text) + + # split to tokens + tokens = text.split() + tokens_out = [] + + # re-join the special cases of punctuation + for ii, tok in enumerate(tokens): + # no rewriting for 1st and last token + if ii > 0 and ii < len(tokens) - 1: + # **RULES ADDED FOR CZECH COMMON VOICE** + + # fix "27 . dubna" -> "27. dubna", but keep punctuation separate, + if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower(): + tokens_out[-1] = tokens_out[-1] + "." + continue + + # fix "resp . pak" -> "resp. pak" + if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower(): + tokens_out[-1] = tokens_out[-1] + "." + continue + + # **RULES ADDED FOR ENGLISH COMMON VOICE** + + # fix "A ." -> "A." + if tok == "." and re.match(r"^[A-Z]S", tokens[ii - 1]): + tokens_out[-1] = tokens_out[-1] + "." + continue + + # fix "Mr ." -> "Mr." + exceptions = set(["Mr", "Mrs", "Ms"]) + if tok == "." and tokens[ii - 1] in exceptions: + tokens_out[-1] = tokens_out[-1] + "." + continue + + tokens_out.append(tok) + + return " ".join(tokens_out) + + +def get_args(): + parser = ArgumentParser( + description="Separate punctuation from words: 'hello.' -> 'hello .'" + ) + parser.add_argument( + "--ignore-columns", type=int, default=1, help="skip number of initial columns" + ) + return parser.parse_args() + + +def main(): + args = get_args() + + max_split = args.ignore_columns + + while True: + line = sys.stdin.readline() + if not line: + break + + *key, text = line.strip().split(maxsplit=max_split) + text_norm = separate_punctuation(text) + + print(" ".join(key), text_norm) + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/text_from_manifest.py b/egs/voxpopuli/ASR/local/text_from_manifest.py new file mode 100755 index 000000000..d9ab53b5a --- /dev/null +++ b/egs/voxpopuli/ASR/local/text_from_manifest.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`. + +Usage example: + python3 ./local/text_from_manifest.py \ + data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz +""" + +import argparse +import gzip +import json + + +def get_args(): + parser = argparse.ArgumentParser( + "Read the raw text from the 'supervisions.jsonl.gz'" + ) + parser.add_argument("filename", help="supervisions.jsonl.gz") + return parser.parse_args() + + +def main(): + args = get_args() + + with gzip.open(args.filename, mode="r") as fd: + for line in fd: + js = json.loads(line) + if "text" in js: + print(js["text"]) # supervisions.jsonl.gz + elif "supervisions" in js: + for s in js["supervisions"]: + print(s["text"]) # cuts.jsonl.gz + else: + raise Exception(f"Unknown jsonl format of {args.filename}") + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/train_bpe_model.py b/egs/voxpopuli/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/voxpopuli/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py new file mode 100755 index 000000000..8e9de905f --- /dev/null +++ b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script introduces initial capital letter at the beginning of a sentence. +It can be used as a module, or as an executable script. + +Usage example #1: + `from uppercase_begin_of_sentence import UpperCaseBeginOfSentence` + +Usage example #2: +``` + python3 ./local/uppercase_begin_of_sentence.py \ + --ignore-columns 1 \ + < ${kaldi_data}/text +``` +""" + +import re +import sys +from argparse import ArgumentParser + + +class UpperCaseBeginOfSentence: + """ + This class introduces initial capital letter at the beginning of a sentence. + Capital letter is used, if previous symbol was punctuation token from + `set([".", "!", "?"])`. + + The punctuation as previous token is memorized also across + `process_line_text()` calls. 
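+
+    Illustrative example (assuming the punctuation has already been separated
+    by `separate_punctuation()`):
+        input : "this is fine . yes , you are right ."
+        output: "This is fine . Yes , you are right ."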
+ """ + + def __init__(self): + # The 1st word will have Title-case + # This variable transfers context from previous line + self.prev_token_is_punct = True + + def process_line_text(self, line_text: str) -> str: + """ + It is assumed that punctuation in `line_text` was already separated, + example: "This is fine . Yes , you are right ." + """ + + words = line_text.split() + punct_set = set([".", "!", "?"]) + + for ii, w in enumerate(words): + # punctuation ? + if w in punct_set: + self.prev_token_is_punct = True + continue + + # change case of word... + if self.prev_token_is_punct: + if re.match("<", w): + continue # skip + # apply Title-case only on lowercase words. + if w.islower(): + words[ii] = w.title() + # change state + self.prev_token_is_punct = False + + line_text_uc = " ".join(words) + + return line_text_uc + + +def get_args(): + parser = ArgumentParser( + description="Put upper-case at the beginning of a sentence." + ) + parser.add_argument( + "--ignore-columns", type=int, default=4, help="skip number of initial columns" + ) + return parser.parse_args() + + +def main(): + args = get_args() + + uc_bos = UpperCaseBeginOfSentence() + max_split = args.ignore_columns + + while True: + line = sys.stdin.readline() + if not line: + break + line = line.strip() + + if len(line.split()) > 1: + *key, text = line.strip().split(maxsplit=max_split) # parse, + text_uc = uc_bos.process_line_text(text) # process, + print(" ".join(key), text_uc) # print, + else: + print(line) + + +if __name__ == "__main__": + main() diff --git a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py new file mode 100755 index 000000000..4659aa9cd --- /dev/null +++ b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Brno University of Technology (authors: Karel Veselý) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within Cut time bounds +- Duration of Cut and Superivion are equal + +We will add more checks later if needed. 
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz + +(Based on: `librispeech/ASR/local/validate_manifest.py`) +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset.speech_recognition import validate_for_asr + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "cutset_manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + tol = 2e-3 # same tolerance as in 'validate_for_asr()' + s = c.supervisions[0] + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + if s.start < -tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} must not be negative." + ) + if s.start > tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} " + "is not at the beginning of the Cut. " + "Please apply `lhotse cut trim-to-supervisions`." + ) + if c.start + s.end > c.end + tol: + raise ValueError( + f"{c.id}: Supervision end time {c.start+s.end} is larger " + f"than cut end time {c.end}" + ) + + if s.duration != c.duration: + raise ValueError( + f"{c.id}: Cut duration {c.duration} and supervision duration " + f"{s.duration} must be the same.\n" + f"The difference causes problems in the training code : " + f"+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n" + f"Did you forget to apply `trim_to_supervisions()` ?" + ) + + +def main(): + args = get_args() + + manifest = args.cutset_manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + try: + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + # Validation from K2 training + # - checks supervision start is 0 + # - checks supervision.duration is not longer than cut.duration + # - there is tolerance 2ms + validate_for_asr(cut_set) + except BaseException as e: + logging.error(str(e)) + raise + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/voxpopuli/ASR/prepare.sh b/egs/voxpopuli/ASR/prepare.sh new file mode 100755 index 000000000..7cddad756 --- /dev/null +++ b/egs/voxpopuli/ASR/prepare.sh @@ -0,0 +1,257 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -euxo pipefail + +nj=20 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. 
+# +# - $dl_dir/voxpopuli/raw_audios/$lang/$year +# This directory contains *.ogg files with audio downloaded and extracted from archives: +# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar +# +# - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder +# as part of `lhotse prepare voxpopuli` from: +# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download +#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT + +musan_dir=${dl_dir}/musan +#musan_dir=/mnt/matylda2/data/MUSAN # BUT + +# Choose value from ASR_LANGUAGES: +# +# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr", +# "sk", "sl", "et", "lt" ] +# +# See ASR_LANGUAGES in: +# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4 +lang=en + +task=asr + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/${lang}/lang_bpe_xxx, +# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" +log "musan_dir: $musan_dir" +log "task: $task, lang: $lang" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/voxpopuli/raw_audios/${lang} ]; then + lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $musan_dir/musan ]; then + lhotse download musan $musan_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare VoxPopuli manifest" + # We assume that you have downloaded the VoxPopuli corpus + # to $dl_dir/voxpopuli + if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then + # Warning : it requires Internet connection (it downloads transcripts to ${tmpdir}) + lhotse prepare voxpopuli --task asr --lang $lang -j $nj $dl_dir/voxpopuli data/manifests + touch data/manifests/.voxpopuli-${task}-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + #lhotse prepare musan $dl_dir/musan data/manifests + lhotse prepare musan $musan_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Preprocess VoxPopuli manifest" + mkdir -p data/fbank + if [ ! 
-e data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete ]; then
+    # recordings + supervisions -> cutset
+    ./local/preprocess_voxpopuli.py --task $task --lang $lang \
+      --use-original-text True
+    touch data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli"
+  mkdir -p data/fbank
+  for dataset in "dev" "test"; do
+    if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then
+      ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+        --num-jobs 50 --num-workers ${nj} \
+        --prefix "voxpopuli-${task}-${lang}" \
+        --dataset ${dataset} \
+        --trim-to-supervisions True
+      touch data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done
+    fi
+  done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for train set of VoxPopuli"
+  if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then
+    ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+      --num-jobs 100 --num-workers ${nj} \
+      --prefix "voxpopuli-${task}-${lang}" \
+      --dataset train \
+      --trim-to-supervisions True \
+      --speed-perturb True
+    touch data/fbank/.voxpopuli-${task}-${lang}-train.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Validate fbank manifests for VoxPopuli"
+  for dataset in "dev" "test" "train"; do
+    mkdir -p data/fbank/log/
+    ./local/validate_cutset_manifest.py \
+      data/fbank/voxpopuli-asr-en_cuts_${dataset}.jsonl.gz \
+      2>&1 | tee data/fbank/log/validate_voxpopuli-asr-en_cuts_${dataset}.log
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compute fbank for musan"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}_${lang}
+    mkdir -p $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data for BPE training"
+      file=$(
+        find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz"
+      )
+      local/text_from_manifest.py $file >$lang_dir/transcript_words.txt
+      # gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
+
+      # Ensure space only appears once
+      #sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
+      #sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/words.txt ]; then
+      cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+        | sort -u | sed '/^$/d' > $lang_dir/words.txt
+      (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+      cat - $lang_dir/words.txt | sort | uniq | awk '
+      BEGIN {
+        print "<eps> 0";
+      }
+      {
+        if ($1 == "<s>") {
+          print "<s> is in the vocabulary!" | "cat 1>&2"
+          exit 1;
+        }
+        if ($1 == "</s>") {
+          print "</s> is in the vocabulary!" | "cat 1>&2"
+          exit 1;
+        }
+        printf("%s %d\n", $1, NR);
+      }
+      END {
+        printf("#0 %d\n", NR+1);
+        printf("<s> %d\n", NR+2);
+        printf("</s> %d\n", NR+3);
+      }' > $lang_dir/words || exit 1;
+      mv $lang_dir/words $lang_dir/words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ !
-f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi diff --git a/egs/voxpopuli/ASR/shared b/egs/voxpopuli/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/voxpopuli/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From 666d69b20d53796420593d99b0c0d6e9cd2212cc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 17 Nov 2023 18:12:59 +0800 Subject: [PATCH 102/113] Rename train2.py to avoid confusion (#1386) --- .github/scripts/run-multi-zh_hans-zipformer.sh | 4 +++- egs/aishell/ASR/prepare.sh | 5 ++--- .../{train2.py => do_not_use_it_directly.py} | 1 + egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py | 2 +- .../{train2.py => do_not_use_it_directly.py} | 1 + .../ASR/pruned_transducer_stateless7_streaming/README.md | 4 ++-- .../{train2.py => do_not_use_it_directly.py} | 1 + .../{train2.py => do_not_use_it_directly.py} | 1 + .../export-for-ncnn.py | 2 +- .../{train2.py => do_not_use_it_directly.py} | 1 + .../conv_emformer_transducer_stateless2/export-for-ncnn.py | 2 +- .../ASR/conv_emformer_transducer_stateless2/export-onnx.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/README.md | 4 ++-- .../{train2.py => do_not_use_it_directly.py} | 1 + .../export-for-ncnn-zh.py | 2 +- .../export-for-ncnn.py | 2 +- .../do_not_use_it_directly.py | 1 + .../export-for-ncnn.py | 2 +- .../pruned_transducer_stateless7_streaming_multi/train2.py | 1 - 19 files changed, 23 insertions(+), 16 deletions(-) rename egs/aishell/ASR/pruned_transducer_stateless7/{train2.py => do_not_use_it_directly.py} (99%) rename egs/aishell/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) rename egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) rename egs/csj/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) rename egs/librispeech/ASR/conv_emformer_transducer_stateless2/{train2.py => do_not_use_it_directly.py} (99%) rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh index dd32a94f8..cbd86a4d3 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -51,6 +51,8 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000002.wav done +rm -rf $repo + log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ @@ -92,4 +94,4 @@ for method in modified_beam_search fast_beam_search; do 
$repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav -done \ No newline at end of file +done diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index d36dc5ed3..9f73a2073 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -261,10 +261,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then fi if [ ! -f $lang_char_dir/HLG.fst ]; then - lang_phone_dir=data/lang_phone ./local/prepare_lang_fst.py \ - --lang-dir $lang_phone_dir \ - --ngram-G ./data/lm/G_3_gram.fst.txt + --lang-dir $lang_char_dir \ + --ngram-G ./data/lm/G_3_gram_char.fst.txt fi fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py similarity index 99% rename from egs/aishell/ASR/pruned_transducer_stateless7/train2.py rename to egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py index 057af297f..6027273b2 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py @@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() AsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index 2a9fc57d5..39d988cd0 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -56,7 +56,7 @@ import torch.nn as nn from decoder2 import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from zipformer import Zipformer from icefall.checkpoint import ( diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 88eb34104..3c13c19c6 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1233,6 +1233,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md index 991875aaa..6c20bab2c 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md @@ -4,6 +4,6 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer [./emformer.py](./emformer.py) and [./train.py](./train.py) are basically the same as -[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). 
-The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py) is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index c09c9537c..61a3f27db 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1237,6 +1237,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() CommonVoiceAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 4c866ddd8..acde72d80 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1274,6 +1274,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() CSJAsrDataModule.add_arguments(parser) Tokenizer.add_arguments(parser) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index ebdb596a5..b210430c6 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -72,7 +72,7 @@ from pathlib import Path import torch from scaling_converter import convert_scaled_to_non_scaled from tokenizer import Tokenizer -from train2 import add_model_arguments, get_params, get_transducer_model +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py similarity index 99% rename from egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py rename to egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py index 420dc1065..d614f0914 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py @@ -1099,6 +1099,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 85dbd4661..953f95c45 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -39,8 +39,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index ab046557f..1e59e0858 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -61,7 +61,7 @@ import torch.nn as nn from decoder import Decoder from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md index d3691e647..0f3c63e75 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -4,7 +4,7 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer [./emformer.py](./emformer.py) and [./train.py](./train.py) are basically the same as -[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). -The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py) is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
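+
+A minimal usage sketch of the export entry point (not part of this change;
+paths, `--epoch` and `--avg` values are placeholders, run the script with
+`--help` for the full option list):
+
+```bash
+cd egs/librispeech/ASR
+./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9
+```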
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py similarity index 99% rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index aa6c0668a..cd26db6f3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom( def main(): + raise RuntimeError("Please don't use this file directly!") parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index 07de57a86..a7d06a5dd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -68,8 +68,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 9a6b31268..8f2178b1d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -66,8 +66,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py new file mode 120000 index 000000000..beeffaa03 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/do_not_use_it_directly.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 9a6b31268..8f2178b1d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -66,8 +66,8 @@ from pathlib import Path import k2 import torch +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py deleted file mode 120000 index 3c3280b68..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7_streaming/train2.py \ No newline at end of file From 11d816d174076ec9485ab8b1d36af2592514e348 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sat, 18 Nov 2023 18:47:55 +0800 Subject: [PATCH 103/113] Add cumstomized score for hotwords (#1385) * add custom score for each hotword * Add more comments * Fix deocde * fix style * minor fixes --- .../pruned_transducer_stateless7/decode.py | 2 +- .../decode.py | 2 +- .../pruned_transducer_stateless4/decode.py | 4 +- egs/librispeech/ASR/zipformer/decode.py | 4 +- .../pruned_transducer_stateless5/decode.py | 2 +- icefall/context_graph.py | 117 +++++++++++++----- 6 files changed, 92 insertions(+), 39 deletions(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py index be58c4e43..696eea906 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py @@ -641,7 +641,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py index f5ae836fd..99110d6b6 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -686,7 +686,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 524366068..5195a4ef6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -927,9 +927,9 @@ def main(): if os.path.exists(params.context_file): contexts = [] for line in open(params.context_file).readlines(): - contexts.append(line.strip()) + contexts.append((sp.encode(line.strip()), 0.0)) context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) + context_graph.build(contexts) else: context_graph = None else: diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 3531d657f..339e253e6 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1001,9 +1001,9 @@ def main(): if os.path.exists(params.context_file): contexts = [] for line in open(params.context_file).readlines(): - contexts.append(line.strip()) + contexts.append((sp.encode(line.strip()), 0.0)) context_graph = ContextGraph(params.context_score) - context_graph.build(sp.encode(contexts)) + context_graph.build(contexts) else: context_graph = None else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 36b8a4b67..d665f3364 100755 
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -868,7 +868,7 @@ def main(): contexts_text.append(line.strip()) contexts = graph_compiler.texts_to_ids(contexts_text) context_graph = ContextGraph(params.context_score) - context_graph.build(contexts) + context_graph.build([(c, 0.0) for c in contexts]) else: context_graph = None else: diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 0b7c42c0b..b3d7972a8 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -84,6 +84,9 @@ class ContextGraph: context_score: The bonus score for each token(note: NOT for each word/phrase, it means longer word/phrase will have larger bonus score, they have to be matched though). + Note: This is just the default score for each token, the users can manually + specify the context_score for each word/phrase (i.e. different phrase might + have different token score). """ self.context_score = context_score self.num_nodes = 0 @@ -133,7 +136,7 @@ class ContextGraph: node.output_score += 0 if output is None else output.output_score queue.append(node) - def build(self, token_ids: List[List[int]]): + def build(self, token_ids: List[Tuple[List[int], float]]): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc for each trie node. @@ -142,26 +145,46 @@ class ContextGraph: Args: token_ids: - The given token lists to build the ContextGraph, it is a list of token list, - each token list contains the token ids for a word/phrase. The token id - could be an id of a char (modeling with single Chinese char) or an id - of a BPE (modeling with BPEs). + The given token lists to build the ContextGraph, it is a list of tuple of + token list and its customized score, the token list contains the token ids + for a word/phrase. The token id could be an id of a char + (modeling with single Chinese char) or an id of a BPE + (modeling with BPEs). The score is the total score for current token list, + 0 means using the default value (i.e. self.context_score). + + Note: The phrases would have shared states, the score of the shared states is + the maximum value among all the tokens sharing this state. """ - for tokens in token_ids: + for (tokens, score) in token_ids: node = self.root + # If has customized score using the customized token score, otherwise + # using the default score + context_score = ( + self.context_score if score == 0.0 else round(score / len(tokens), 2) + ) for i, token in enumerate(tokens): + node_next = {} if token not in node.next: self.num_nodes += 1 + node_id = self.num_nodes + token_score = context_score is_end = i == len(tokens) - 1 - node_score = node.node_score + self.context_score - node.next[token] = ContextState( - id=self.num_nodes, - token=token, - token_score=self.context_score, - node_score=node_score, - output_score=node_score if is_end else 0, - is_end=is_end, - ) + else: + # node exists, get the score of shared state. 
+ token_score = max(context_score, node.next[token].token_score) + node_id = node.next[token].id + node_next = node.next[token].next + is_end = i == len(tokens) - 1 or node.next[token].is_end + node_score = node.node_score + token_score + node.next[token] = ContextState( + id=node_id, + token=token, + token_score=token_score, + node_score=node_score, + output_score=node_score if is_end else 0, + is_end=is_end, + ) + node.next[token].next = node_next node = node.next[token] self._fill_fail_output() @@ -343,7 +366,7 @@ class ContextGraph: return dot -if __name__ == "__main__": +def _test(queries, score): contexts_str = [ "S", "HE", @@ -355,9 +378,11 @@ if __name__ == "__main__": "THIS", "THEM", ] + + # test default score (1) contexts = [] for s in contexts_str: - contexts.append([ord(x) for x in s]) + contexts.append(([ord(x) for x in s], score)) context_graph = ContextGraph(context_score=1) context_graph.build(contexts) @@ -369,10 +394,28 @@ if __name__ == "__main__": context_graph.draw( title="Graph for: " + " / ".join(contexts_str), - filename="context_graph.pdf", + filename=f"context_graph_{score}.pdf", symbol_table=symbol_table, ) + for query, expected_score in queries.items(): + total_scores = 0 + state = context_graph.root + for q in query: + score, state = context_graph.forward_one_step(state, ord(q)) + total_scores += score + score, state = context_graph.finalize(state) + assert state.token == -1, state.token + total_scores += score + assert round(total_scores, 2) == expected_score, ( + total_scores, + expected_score, + query, + ) + + +if __name__ == "__main__": + # test default score queries = { "HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE" "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" @@ -384,17 +427,27 @@ if __name__ == "__main__": "DHRHISQ": 4, # "HIS", "S" "THEN": 2, # "HE" } - for query, expected_score in queries.items(): - total_scores = 0 - state = context_graph.root - for q in query: - score, state = context_graph.forward_one_step(state, ord(q)) - total_scores += score - score, state = context_graph.finalize(state) - assert state.token == -1, state.token - total_scores += score - assert total_scores == expected_score, ( - total_scores, - expected_score, - query, - ) + _test(queries, 0) + + # test custom score (5) + # S : 5 + # HE : 5 (2.5 + 2.5) + # SHE : 8.34 (5 + 1.67 + 1.67) + # SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1) + # HIS : 5.84 (2.5 + 1.67 + 1.67) + # HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25) + # HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1) + # THIS : 5 (1.25 + 1.25 + 1.25 + 1.25) + queries = { + "HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE" + "HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE" + "HISHE": 24.18, # "HIS", "S", "SHE", "HE" + "SHED": 18.34, # "S", "SHE", "HE" + "SHELF": 18.34, # "S", "SHE", "HE" + "HELL": 5, # "HE" + "HELLO": 13, # "HE", "HELLO" + "DHRHISQ": 10.84, # "HIS", "S" + "THEN": 5, # "HE" + } + + _test(queries, 5) From 238b45bea85deee1a07cfd0f55b485cc92f67135 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 23 Nov 2023 01:22:57 +0800 Subject: [PATCH 104/113] Libriheavy recipe (zipformer) (#1261) * initial commit for libriheavy * Data prepare pipeline * Fix train.py * Fix decode.py * Add results * minor fixes * black * black * Incorporate PR https://github.com/k2-fsa/icefall/pull/1269 --------- Co-authored-by: zr_jin --- egs/libriheavy/ASR/README.md | 6 + egs/libriheavy/ASR/RESULTS.md | 114 +- .../ASR/local/compute_fbank_libriheavy.py | 242 +++ .../ASR/local/compute_fbank_musan.py | 1 + egs/libriheavy/ASR/local/norm_text.py | 58 + 
egs/libriheavy/ASR/local/prepare_manifest.py | 47 + egs/libriheavy/ASR/local/train_bpe_model.py | 113 ++ egs/libriheavy/ASR/prepare.sh | 314 ++++ .../ASR/zipformer/asr_datamodule.py | 443 ++++++ egs/libriheavy/ASR/zipformer/beam_search.py | 1 + egs/libriheavy/ASR/zipformer/decode.py | 794 +++++++++ egs/libriheavy/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/libriheavy/ASR/zipformer/export-onnx.py | 1 + egs/libriheavy/ASR/zipformer/export.py | 1 + .../ASR/zipformer/jit_pretrained.py | 1 + egs/libriheavy/ASR/zipformer/joiner.py | 1 + egs/libriheavy/ASR/zipformer/model.py | 1 + egs/libriheavy/ASR/zipformer/onnx_decode.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/libriheavy/ASR/zipformer/optim.py | 1 + egs/libriheavy/ASR/zipformer/pretrained.py | 1 + egs/libriheavy/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_coverter.py | 1 + egs/libriheavy/ASR/zipformer/subsampling.py | 1 + .../ASR/zipformer/text_normalization.py | 50 + egs/libriheavy/ASR/zipformer/train.py | 1415 +++++++++++++++++ egs/libriheavy/ASR/zipformer/zipformer.py | 1 + requirements-ci.txt | 1 + requirements.txt | 1 + 30 files changed, 3613 insertions(+), 2 deletions(-) create mode 100644 egs/libriheavy/ASR/README.md create mode 100755 egs/libriheavy/ASR/local/compute_fbank_libriheavy.py create mode 120000 egs/libriheavy/ASR/local/compute_fbank_musan.py create mode 100755 egs/libriheavy/ASR/local/norm_text.py create mode 100755 egs/libriheavy/ASR/local/prepare_manifest.py create mode 100755 egs/libriheavy/ASR/local/train_bpe_model.py create mode 100755 egs/libriheavy/ASR/prepare.sh create mode 100644 egs/libriheavy/ASR/zipformer/asr_datamodule.py create mode 120000 egs/libriheavy/ASR/zipformer/beam_search.py create mode 100644 egs/libriheavy/ASR/zipformer/decode.py create mode 120000 egs/libriheavy/ASR/zipformer/decoder.py create mode 120000 egs/libriheavy/ASR/zipformer/encoder_interface.py create mode 120000 egs/libriheavy/ASR/zipformer/export-onnx.py create mode 120000 egs/libriheavy/ASR/zipformer/export.py create mode 120000 egs/libriheavy/ASR/zipformer/jit_pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/joiner.py create mode 120000 egs/libriheavy/ASR/zipformer/model.py create mode 120000 egs/libriheavy/ASR/zipformer/onnx_decode.py create mode 120000 egs/libriheavy/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/optim.py create mode 120000 egs/libriheavy/ASR/zipformer/pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/scaling.py create mode 120000 egs/libriheavy/ASR/zipformer/scaling_coverter.py create mode 120000 egs/libriheavy/ASR/zipformer/subsampling.py create mode 100644 egs/libriheavy/ASR/zipformer/text_normalization.py create mode 100644 egs/libriheavy/ASR/zipformer/train.py create mode 120000 egs/libriheavy/ASR/zipformer/zipformer.py diff --git a/egs/libriheavy/ASR/README.md b/egs/libriheavy/ASR/README.md new file mode 100644 index 000000000..2498d017f --- /dev/null +++ b/egs/libriheavy/ASR/README.md @@ -0,0 +1,6 @@ +# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context + +Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105). 
+ + +See [RESULTS](./RESULTS.md) for the results for icefall recipes. diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md index 4fbedad98..513bbf72e 100644 --- a/egs/libriheavy/ASR/RESULTS.md +++ b/egs/libriheavy/ASR/RESULTS.md @@ -1,6 +1,116 @@ -## Results +# Results -### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) +## zipformer (zipformer + pruned stateless transducer) + +See for more details. + +[zipformer](./zipformer) + +### Non-streaming + +#### Training on normalized text, i.e. Upper case without punctuation + +##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M + +You can find a pretrained model, training logs at: + + +Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), +exp_small_subset(small set). + +Results of models: + +| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment | +|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| +| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 | +| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 | +| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 | +| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 | +| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 | +| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python ./zipformer/train.py \ + --world-size 4 \ + --master-port 12365 \ + --exp-dir zipformer/exp \ + --num-epochs 60 \ # 16 for large; 90 for small + --lr-hours 15000 \ # 20000 for large; 5000 for small + --use-fp16 1 \ + --start-epoch 1 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --max-duration 1000 \ + --subset medium +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 16 \ + --avg 3 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --causal 0 \ + --decoding-method $m +done +``` + +#### Training on full formatted text, i.e. with casing and punctuation + +##### normal-scaled model, number of model parameters: 66074067 , i.e., 66M + +You can find a pretrained model, training logs at: + + +Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), +exp_small_subset(small set). 
+ +Results of models: + +| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment | +|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| +| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 | +| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 | +| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python ./zipformer/train.py \ + --world-size 4 \ + --master-port 12365 \ + --exp-dir zipformer/exp \ + --num-epochs 60 \ # 16 for large; 90 for small + --lr-hours 15000 \ # 20000 for large; 10000 for small + --use-fp16 1 \ + --train-with-punctuation 1 \ + --start-epoch 1 \ + --bpe-model data/lang_punc_bpe_756/bpe.model \ + --max-duration 1000 \ + --subset medium +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 16 \ + --avg 3 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --causal 0 \ + --decoding-method $m +done +``` + +## Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) #### [zipformer_prompt_asr](./zipformer_prompt_asr) diff --git a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py new file mode 100755 index 000000000..010531db2 --- /dev/null +++ b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the Libriheavy dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, +) + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-dir", + type=str, + help="""The source directory that contains raw manifests. 
+ """, + default="data/manifests", + ) + + parser.add_argument( + "--fbank-dir", + type=str, + help="""Fbank output dir + """, + default="data/fbank", + ) + + parser.add_argument( + "--subset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Whether to use speed perturbation.", + ) + + parser.add_argument( + "--use-splits", + type=str2bool, + default=False, + help="Whether to compute fbank on splits.", + ) + + parser.add_argument( + "--num-splits", + type=int, + help="""The number of splits of the medium and large subset. + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="""Process pieces starting from this number (inclusive). + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="""Stop processing pieces until this number (exclusive). + Only needed when --use-splits is true.""", + ) + + return parser.parse_args() + + +def compute_fbank_libriheavy(args): + src_dir = Path(args.manifest_dir) + output_dir = Path(args.fbank_dir) + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + subset = args.subset + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + if output_cuts_path.exists(): + logging.info(f"{output_cuts_path} exists - skipping") + return + + input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!" 
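+        # Note: this single-pass path is used for the smaller manifests
+        # (small/dev/test_clean/test_other); the medium and large subsets are
+        # instead processed piece by piece in compute_fbank_libriheavy_splits() below.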
+ logging.info(f"Loading {input_cuts_path}") + cut_set = CutSet.from_file(input_cuts_path) + + logging.info("Computing features") + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + logging.info(f"Saving to {output_cuts_path}") + cut_set.to_file(output_cuts_path) + + +def compute_fbank_libriheavy_splits(args): + num_splits = args.num_splits + subset = args.subset + src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split" + src_dir = Path(src_dir) + output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + num_digits = 8 # num_digits is fixed by lhotse split-lazy + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + continue + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Computing features") + if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists(): + logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca") + os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca") + + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + overwrite=True, + ) + + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + if args.use_splits: + assert args.num_splits is not None, "Please provide num_splits" + compute_fbank_libriheavy_splits(args) + else: + compute_fbank_libriheavy(args) diff --git a/egs/libriheavy/ASR/local/compute_fbank_musan.py b/egs/libriheavy/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/libriheavy/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py new file mode 100755 index 000000000..c2fc0d92d --- /dev/null +++ b/egs/libriheavy/ASR/local/norm_text.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import codecs +import sys + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="""Path to the input text. + """, + ) + return parser.parse_args() + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def main(): + args = get_args() + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer) + line = f.readline() + while line: + print(remove_punc_to_upper(line)) + line = f.readline() + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py new file mode 100755 index 000000000..42f392cae --- /dev/null +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import json +import sys +from pathlib import Path + + +def simple_cleanup(text: str) -> str: + table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") + text = text.translate(table) + return text.strip() + + +# Assign text of the supervisions and remove unnecessary entries. 
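+# simple_cleanup() above maps full-width punctuation to its ASCII counterpart and
+# strips surrounding whitespace, e.g. (illustrative string only):
+#   simple_cleanup('“Hello，world。”') -> '"Hello,world."'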
+def main(): + assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR" + fname = Path(sys.argv[1]).name + oname = Path(sys.argv[2]) / fname + with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: + for line in fin: + cut = json.loads(line) + cut["supervisions"][0]["text"] = simple_cleanup( + cut["supervisions"][0]["custom"]["texts"][0] + ) + del cut["supervisions"][0]["custom"] + del cut["custom"] + fout.write((json.dumps(cut) + "\n").encode()) + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..19caf43ab --- /dev/null +++ b/egs/libriheavy/ASR/local/train_bpe_model.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--byte-fallback", + action="store_true", + help="""Whether to enable byte_fallback when training bpe.""", + ) + + parser.add_argument( + "--character-coverage", + type=float, + default=1.0, + help="Character coverage in vocabulary.", + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. 
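+    # With the two user-defined symbols above occupying piece ids 0 and 1, the
+    # unknown symbol gets id 2; bos/eos pieces are disabled (bos_id=-1, eos_id=-1 below).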
+ + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=args.character_coverage, + user_defined_symbols=user_defined_symbols, + byte_fallback=args.byte_fallback, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh new file mode 100755 index 000000000..af7e3c5b0 --- /dev/null +++ b/egs/libriheavy/ASR/prepare.sh @@ -0,0 +1,314 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 +export CUDA_VISIBLE_DEVICES="" + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/librilight +# You can find small, medium, large, etc. inside it. +# +# - $dl_dir/libriheavy +# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +fbank_dir=data/fbank +manifests_dir=data/manifests + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download audio data." + # If you have pre-downloaded it to /path/to/librilight, + # you can create a symlink + # + # ln -sfv /path/to/librilight $dl_dir/librilight + # + mkdir -p $dl_dir/librilight + for subset in small medium large; do + log "Downloading ${subset} subset." + if [ ! -d $dl_dir/librilight/${subset} ]; then + wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar + tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight + else + log "Skipping download, ${subset} subset exists." + fi + done +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download manifests from huggingface." + + # If you have pre-downloaded it to /path/to/libriheavy, + # you can create a symlink + # + # ln -sfv /path/to/libriheavy $dl_dir/libriheavy + # + mkdir -p $dl_dir/libriheavy + for subset in small medium large dev test_clean test_other; do + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Downloading ${subset} subset." + wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz + else + log "Skipping download, ${subset} subset exists." 
+ fi + done + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Download manifests from modelscope" + mkdir -p $dl_dir/libriheavy + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_small.jsonl.gz ]; then + cd $dl_dir/libriheavy + GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/datasets/pkufool/Libriheavy.git + cd Libriheavy + git lfs pull --exclude "raw/*" + mv *.jsonl.gz ../ + cd .. + rm -rf Libriheavy + cd ../../ + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p $manifests_dir + if [ ! -e $manifests_dir/.musan.done ]; then + lhotse prepare musan $dl_dir/musan $manifests_dir + touch $manifests_dir/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare Libriheavy manifests" + mkdir -p $manifests_dir + for subset in small medium large dev test_clean test_other; do + if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Prepare manifest for subset : ${subset}" + ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir + fi + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p $fbank_dir + if [ ! -e $fbank_dir/.musan.done ]; then + ./local/compute_fbank_musan.py + touch $fbank_dir/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank for small subset and validation subsets" + for subset in test_clean test_other dev small; do + log "Computing $subset subset." + if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then + ./local/compute_fbank_libriheavy.py \ + --manifest-dir ${manifests_dir} \ + --subset ${subset} \ + --fbank-dir $fbank_dir \ + --num-workers $nj + fi + done +fi + +num_per_split=8000 +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split medium and large subsets." + for subset in medium large; do + log "Spliting subset : $subset" + split_dir=$manifests_dir/libriheavy_${subset}_split + mkdir -p $split_dir + if [ ! -e $split_dir/.split_completed ]; then + lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split + touch $split_dir/.split_completed + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for medium and large subsets" + mkdir -p $fbank_dir + chunk_size=20 + for subset in medium large; do + if [ $subset == "large" ]; then + chunk_size=200 + fi + num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l) + if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then + for i in $(seq 0 1 6); do + start=$(( i * $chunk_size )) + end=$(( (i+1) * $chunk_size )) + ./local/compute_fbank_libriheavy.py \ + --manifest-dir ${manifests_dir} \ + --use-splits 1 \ + --subset ${subset} \ + --fbank-dir $fbank_dir \ + --num-splits $num_splits \ + --num-workers $nj \ + --start $start \ + --stop $end & + done + wait + touch $fbank_dir/.libriheavy.${subset}.done + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Combine features for medium and large subsets." + for subset in medium large; do + log "Combining $subset subset." + if [ ! 
-f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then + pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz") + lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz + fi + done +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Train BPE model for normalized text" + + if [ ! -f data/texts ]; then + gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \ + | ./local/norm_text.py > data/texts + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + cp data/texts $lang_dir/text + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + done +fi + + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Train BPE model for unnormalized text" + if [ ! -f data/punc_texts ]; then + gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts + fi + for vocab_size in ${vocab_sizes[@]}; do + new_vacab_size = $(($vocab_size + 256)) + lang_dir=data/lang_punc_bpe_${new_vocab_size} + mkdir -p $lang_dir + + cp data/punc_texts $lang_dir/text + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --byte-fallback \ + --vocab-size ${new_vocab_size} \ + --byte-fallback \ + --character-coverage 0.99 \ + --transcript $lang_dir/text + fi + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Prepare language model for normalized text" + + for subset in small medium large; do + if [ ! -f $manifests_dir/texts_${subset} ]; then + gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \ + | ./local/norm_text.py > $manifests_dir/texts_${subset} + fi + done + + mkdir -p data/lm + if [ ! -f data/lm/text ]; then + cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text + fi + + (echo ' 0'; echo '!SIL 1'; echo ' 2'; echo ' 3';) \ + > data/lm/words.txt + + cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ + | awk '{print $1" "NR+3}' >> data/lm/words.txt + + num_lines=$(< data/lm/words.txt wc -l) + (echo "#0 $num_lines"; echo " $(($num_lines + 1))"; echo " $(($num_lines + 2))";) \ + >> data/lm/words.txt + + # Train LM on transcripts + if [ ! -f data/lm/3-gram.unpruned.arpa ]; then + python3 ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text data/lm/text \ + -lm data/lm/3-gram.unpruned.arpa + fi + + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + if [ ! 
-f data/lm/G_3_gram_char.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table=data/lm/words.txt \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt + fi +fi + diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..df761c1b8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,443 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriHeavyAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--subset", + type=str, + default="S", + help="""The subset to be used. Should be S, M or L. Note: S subset + includes libriheavy_cuts_small.jsonl.gz, M subset includes + libriheavy_cuts_small.jsonl.gz and libriheavy_cuts_medium.jsonl.gz, + L subset includes libriheavy_cuts_small.jsonl.gz, + libriheavy_cuts_medium.jsonl.gz and libriheavy_cuts_large.jsonl.gz. 
+ """, + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
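+        # Each dataloader worker then calls fix_random_seed(seed + worker_id) via
+        # _SeedWorkers, so workers draw different, but reproducible, random augmentations.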
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_small_cuts(self) -> CutSet: + logging.info("About to get small subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz" + ) + + @lru_cache() + def train_medium_cuts(self) -> CutSet: + logging.info("About to get medium subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz" + ) + + @lru_cache() + def train_large_cuts(self) -> CutSet: + logging.info("About to get large subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get the test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get the test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz" + ) diff --git a/egs/libriheavy/ASR/zipformer/beam_search.py b/egs/libriheavy/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py new file mode 100644 index 000000000..1928e2635 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/decode.py @@ 
-0,0 +1,794 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from text_normalization import remove_punc_to_upper +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="""Set to True, if the model was trained on texts with casing + and punctuation.""", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=False, + help="""Upper case and remove all chars except ' and - + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
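+        # Right-pad the features with pad_len frames of LOG_EPS (log of a tiny value)
+        # and extend feature_lens accordingly, so that the last frames of a causal
+        # (streaming) model still receive some right context.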
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. 
+ decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + this_batch = [] + if params.post_normalization and params.train_with_punctuation: + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = remove_punc_to_upper(ref_text).split() + hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[f"{name}_norm"].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
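+            # write_error_stats() appends per-word error counts and aligned ref/hyp
+            # pairs to f and returns the overall WER for this decoding setting.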
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") 
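+            # Note: average_checkpoints() below simply averages the parameter
+            # tensors of the selected epoch-*/checkpoint-* files, whereas the
+            # --use-averaged-model branch further down relies on the running
+            # `model_avg` maintained during training (see --average-period in
+            # ./zipformer/train.py).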
+ model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
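+    # Setting return_cuts=True makes the test dataloaders attach the original
+    # lhotse Cut to each supervision, which is what allows decode_dataset()
+    # above to read batch["supervisions"]["cut"] and recover the cut ids.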
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + + if not params.train_with_punctuation: + test_clean_cuts = test_clean_cuts.map(normalize_text) + test_other_cuts = test_other_cuts.map(normalize_text) + + test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts) + test_other_dl = libriheavy.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer/decoder.py b/egs/libriheavy/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/encoder_interface.py b/egs/libriheavy/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export-onnx.py b/egs/libriheavy/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export.py b/egs/libriheavy/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/jit_pretrained.py b/egs/libriheavy/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/joiner.py b/egs/libriheavy/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/model.py b/egs/libriheavy/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_decode.py b/egs/libriheavy/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py new file mode 
120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/optim.py b/egs/libriheavy/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/pretrained.py b/egs/libriheavy/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling.py b/egs/libriheavy/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling_coverter.py b/egs/libriheavy/ASR/zipformer/scaling_coverter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/scaling_coverter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/subsampling.py b/egs/libriheavy/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py new file mode 100644 index 000000000..92590769c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/text_normalization.py @@ -0,0 +1,50 @@ +from num2words import num2words + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def word_normalization(word: str) -> str: + # 1. Use full word for some abbreviation + # 2. Convert digits to english words + # 3. 
Convert ordinal number to english words + if word == "MRS": + return "MISSUS" + if word == "MR": + return "MISTER" + if word == "ST": + return "SAINT" + if word == "ECT": + return "ET CETERA" + + if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g 9TH, 6TH + word = num2words(word[:-2], to="ordinal") + word = word.replace("-", " ") + + if word.isnumeric(): + num = int(word) + if num > 1500 and num < 2030: + word = num2words(word, to="year") + else: + word = num2words(word) + word = word.replace("-", " ") + return word.upper() + + +def text_normalization(text: str) -> str: + text = text.upper() + return " ".join([word_normalization(x) for x in text.split()]) + + +if __name__ == "__main__": + assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK" + assert ( + text_normalization("Hello Mrs st 21st world 3rd she 99th MR") + == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER" + ) diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py new file mode 100644 index 000000000..c97da4a11 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -0,0 +1,1415 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from text_normalization import remove_punc_to_upper +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-hours", + type=float, + default=30000, + help="""Number of hours that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="If True, the training text will include casing and punctuation.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
+ + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + 
joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. 
+ + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
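+    # Consequently the entries stored in `info` are sums over the whole batch;
+    # per-frame averages are only formed later, e.g. in compute_validation_loss()
+    # via tot_loss["loss"] / tot_loss["frames"].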
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
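+            # The calls below back-propagate the (scaled) loss and then step the
+            # Eden scheduler twice: once per batch and once with the amount of
+            # audio seen so far expressed in hours, e.g. (hypothetical values)
+            # 10000 batches * 1000 s * 4 GPUs / 3600 ≈ 11111 hours; --lr-hours
+            # (default 30000) controls how quickly the LR decays on that scale.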
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + # Use the number of hours of speech to adjust the learning rate + scheduler.step_epoch( + params.batch_idx_train * params.max_duration * params.world_size / 3600 + ) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and 
`world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_hours) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 2.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_small_cuts() + if params.subset == "M" or params.subset == "L": + train_cuts += libriheavy.train_medium_cuts() + if params.subset == "L": + train_cuts += libriheavy.train_large_cuts() + + if not params.train_with_punctuation: + train_cuts = train_cuts.map(normalize_text) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = libriheavy.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = libriheavy.dev_cuts() + + if not params.train_with_punctuation: + valid_cuts = valid_cuts.map(normalize_text) + + valid_dl = libriheavy.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer/zipformer.py b/egs/libriheavy/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/requirements-ci.txt b/requirements-ci.txt index e1232a768..6c74f688c 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -17,6 +17,7 @@ six git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 kaldialign==0.7.1 +num2words sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 diff --git a/requirements.txt b/requirements.txt index 5a8326619..9502fcbd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ kaldifst kaldilm kaldialign +num2words kaldi-decoder sentencepiece>=0.1.96 tensorboard From ae67f75e9c429d35e8a84d6d70cc8050eae37c86 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 26 Nov 2023 10:04:15 +0800 Subject: [PATCH 105/113] a bilingual recipe similar to the `multi-zh_hans` (#1265) --- ...rmer.sh => run-multi-corpora-zipformer.sh} | 38 + ...er.yml => run-multi-corpora-zipformer.yml} | 10 +- egs/multi_zh_en/ASR/README.md | 19 + egs/multi_zh_en/ASR/RESULTS.md | 44 + egs/multi_zh_en/ASR/local/compile_lg.py | 1 + egs/multi_zh_en/ASR/local/prepare_char.py | 1 + .../ASR/local/prepare_for_bpe_model.py | 65 + egs/multi_zh_en/ASR/local/prepare_lang.py | 1 + .../ASR/local/prepare_lang_bbpe.py | 1 + 
egs/multi_zh_en/ASR/local/prepare_lang_bpe.py | 1 + egs/multi_zh_en/ASR/local/prepare_words.py | 1 + egs/multi_zh_en/ASR/local/text2segments.py | 1 + egs/multi_zh_en/ASR/local/text2token.py | 1 + egs/multi_zh_en/ASR/local/train_bbpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/multi_zh_en/ASR/prepare.sh | 149 ++ egs/multi_zh_en/ASR/shared | 1 + .../ASR/zipformer/asr_datamodule.py | 385 +++++ egs/multi_zh_en/ASR/zipformer/beam_search.py | 1 + egs/multi_zh_en/ASR/zipformer/decode.py | 851 ++++++++++ egs/multi_zh_en/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/multi_zh_en/ASR/zipformer/export-onnx.py | 1 + egs/multi_zh_en/ASR/zipformer/export.py | 541 +++++++ .../ASR/zipformer/generate_averaged_model.py | 193 +++ .../ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/multi_zh_en/ASR/zipformer/joiner.py | 1 + egs/multi_zh_en/ASR/zipformer/model.py | 1 + .../ASR/zipformer/multi_dataset.py | 247 +++ egs/multi_zh_en/ASR/zipformer/onnx_check.py | 1 + egs/multi_zh_en/ASR/zipformer/onnx_decode.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/multi_zh_en/ASR/zipformer/optim.py | 1 + egs/multi_zh_en/ASR/zipformer/pretrained.py | 378 +++++ egs/multi_zh_en/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 1 + egs/multi_zh_en/ASR/zipformer/subsampling.py | 1 + egs/multi_zh_en/ASR/zipformer/train.py | 1416 +++++++++++++++++ egs/multi_zh_en/ASR/zipformer/zipformer.py | 1 + 45 files changed, 4363 insertions(+), 5 deletions(-) rename .github/scripts/{run-multi-zh_hans-zipformer.sh => run-multi-corpora-zipformer.sh} (66%) rename .github/workflows/{run-multi-zh_hans-zipformer.yml => run-multi-corpora-zipformer.yml} (91%) create mode 100644 egs/multi_zh_en/ASR/README.md create mode 100644 egs/multi_zh_en/ASR/RESULTS.md create mode 120000 egs/multi_zh_en/ASR/local/compile_lg.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_char.py create mode 100755 egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bpe.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_words.py create mode 120000 egs/multi_zh_en/ASR/local/text2segments.py create mode 120000 egs/multi_zh_en/ASR/local/text2token.py create mode 120000 egs/multi_zh_en/ASR/local/train_bbpe_model.py create mode 120000 egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/multi_zh_en/ASR/prepare.sh create mode 120000 egs/multi_zh_en/ASR/shared create mode 100644 egs/multi_zh_en/ASR/zipformer/asr_datamodule.py create mode 120000 egs/multi_zh_en/ASR/zipformer/beam_search.py create mode 100755 egs/multi_zh_en/ASR/zipformer/decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/decoder.py create mode 120000 egs/multi_zh_en/ASR/zipformer/encoder_interface.py create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx.py create mode 100755 egs/multi_zh_en/ASR/zipformer/export.py create mode 100755 egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained.py create 
mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/joiner.py create mode 120000 egs/multi_zh_en/ASR/zipformer/model.py create mode 100644 egs/multi_zh_en/ASR/zipformer/multi_dataset.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_check.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/multi_zh_en/ASR/zipformer/optim.py create mode 100755 egs/multi_zh_en/ASR/zipformer/pretrained.py create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling.py create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling_converter.py create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/subsampling.py create mode 100755 egs/multi_zh_en/ASR/zipformer/train.py create mode 120000 egs/multi_zh_en/ASR/zipformer/zipformer.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh similarity index 66% rename from .github/scripts/run-multi-zh_hans-zipformer.sh rename to .github/scripts/run-multi-corpora-zipformer.sh index cbd86a4d3..90f859f43 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-corpora-zipformer.sh @@ -95,3 +95,41 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav done + +rm -rf $repo + +cd ../../../egs/multi_zh_en/ASR +log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \ + --method greedy_search \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav +done + +rm -rf $repo diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml similarity index 91% rename from .github/workflows/run-multi-zh_hans-zipformer.yml rename to .github/workflows/run-multi-corpora-zipformer.yml index 72c0775a7..38f7eb908 100644 --- a/.github/workflows/run-multi-zh_hans-zipformer.yml +++ b/.github/workflows/run-multi-corpora-zipformer.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-name: run-multi-zh_hans-zipformer
+name: run-multi-corpora-zipformer
 
 on:
   push:
@@ -24,12 +24,12 @@ on:
     types: [labeled]
 
 concurrency:
-  group: run_multi-zh_hans_zipformer-${{ github.ref }}
+  group: run_multi-corpora_zipformer-${{ github.ref }}
   cancel-in-progress: true
 
 jobs:
-  run_multi-zh_hans_zipformer:
-    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer'
+  run_multi-corpora_zipformer:
+    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' || github.event.label.name == 'multi-corpora'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -81,4 +81,4 @@ jobs:
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
 
-        .github/scripts/run-multi-zh_hans-zipformer.sh
+        .github/scripts/run-multi-corpora-zipformer.sh
diff --git a/egs/multi_zh_en/ASR/README.md b/egs/multi_zh_en/ASR/README.md
new file mode 100644
index 000000000..29341571d
--- /dev/null
+++ b/egs/multi_zh_en/ASR/README.md
@@ -0,0 +1,19 @@
+# Introduction
+
+This recipe includes scripts for training a Zipformer model using both English and Chinese datasets.
+
+# Included Training Sets
+
+1. LibriSpeech (English)
+2. AiShell-2 (Chinese)
+3. TAL-CSASR (Code-Switching, Chinese and English)
+
+|Dataset| Number of hours| URL|
+|---|---:|---|
+|**TOTAL**|2,547|---|
+|LibriSpeech|960|https://www.openslr.org/12/|
+|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
+|TAL-CSASR|587|https://ai.100tal.com/openData/voice|
+
+
+
diff --git a/egs/multi_zh_en/ASR/RESULTS.md b/egs/multi_zh_en/ASR/RESULTS.md
new file mode 100644
index 000000000..3562d6ac3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/RESULTS.md
@@ -0,0 +1,44 @@
+## Results
+
+### Zh-En datasets BPE-based training results (non-streaming) on Zipformer
+
+This is the [pull request #1265](https://github.com/k2-fsa/icefall/pull/1265) in icefall.
+
+#### Non-streaming (Byte-Level BPE vocab_size=2000)
+
+Best results (number of parameters: ~69M):
+
+The training command:
+
+```
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 35 \
+  --use-fp16 1 \
+  --max-duration 1000 \
+  --num-workers 8
+```
+
+The decoding command:
+
+```
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./zipformer/decode.py \
+    --epoch 34 \
+    --avg 19 \
+    --decoding-method $method
+done
+```
+
+Word Error Rates (WERs) listed below are obtained with the decoding command above and the byte-level BPE model (vocab size 2000).
+
+| Datasets             | TAL-CSASR | TAL-CSASR | AiShell-2 | AiShell-2 | LibriSpeech | LibriSpeech |
+|----------------------|-----------|-----------|-----------|-----------|-------------|-------------|
+| Zipformer WER (%)    | dev       | test      | dev       | test      | test-clean  | test-other  |
+| greedy_search        | 6.65      | 6.69      | 6.57      | 7.03      | 2.43        | 5.70        |
+| modified_beam_search | 6.46      | 6.51      | 6.18      | 6.60      | 2.41        | 5.57        |
+| fast_beam_search     | 6.57      | 6.68      | 6.40      | 6.74      | 2.40        | 5.56        |
+
+The pre-trained model can be found here: https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22; it is trained on the LibriSpeech 960-hour training set (with speed perturbation), the TAL-CSASR training set (with speed perturbation) and AiShell-2 (w/o speed perturbation).
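+
+The corpora are combined by `zipformer/multi_dataset.py` (added later in this patch). The snippet below is only a hedged illustration of one way to mix several cut sets with lhotse; the manifest filenames and the weights are assumptions for illustration (the weights shown are the approximate hours from the table in README.md), not the recipe's actual code.
+
+```python
+# Illustrative sketch only: lazily mix three cut sets with lhotse.
+from lhotse import CutSet, load_manifest_lazy
+
+# Hypothetical manifest paths, following the soft links created in prepare.sh.
+librispeech = load_manifest_lazy("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")
+aishell2 = load_manifest_lazy("data/fbank/aishell2_cuts_train.jsonl.gz")
+tal_csasr = load_manifest_lazy("data/fbank/tal_csasr_cuts_train_set.jsonl.gz")
+
+# CutSet.mux interleaves the corpora lazily; weighting by approximate corpus hours
+# keeps the sampling ratio roughly proportional to the amount of data in each set.
+train_cuts = CutSet.mux(librispeech, aishell2, tal_csasr, weights=[960, 1000, 587])
+```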
+ + diff --git a/egs/multi_zh_en/ASR/local/compile_lg.py b/egs/multi_zh_en/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/multi_zh_en/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_char.py b/egs/multi_zh_en/ASR/local/prepare_char.py new file mode 120000 index 000000000..42743b544 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py new file mode 100755 index 000000000..00514e6bb --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script tokenizes the training transcript by CJK characters +# and saves the result to transcript_chars.txt, which is used +# to train the BPE model later. + +import argparse +from pathlib import Path + +from tqdm.auto import tqdm + +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Output directory. + The generated transcript_chars.txt is saved to this directory. + """, + ) + + parser.add_argument( + "--text", + type=str, + help="Training transcript.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + text = Path(args.text) + + assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" 
+ + transcript_path = lang_dir / "transcript_chars.txt" + + with open(text, "r", encoding="utf-8") as fin: + with open(transcript_path, "w+", encoding="utf-8") as fout: + for line in tqdm(fin): + fout.write(tokenize_by_CJK_char(line) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/local/prepare_lang.py b/egs/multi_zh_en/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py new file mode 120000 index 000000000..9a0b44642 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_lang_bbpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_words.py b/egs/multi_zh_en/ASR/local/prepare_words.py new file mode 120000 index 000000000..ef2b4eaf3 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_words.py @@ -0,0 +1 @@ +../../../aishell2/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2segments.py b/egs/multi_zh_en/ASR/local/text2segments.py new file mode 120000 index 000000000..7d68a39c3 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/text2segments.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2segments.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2token.py b/egs/multi_zh_en/ASR/local/text2token.py new file mode 120000 index 000000000..ce5cfd537 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/text2token.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/train_bbpe_model.py b/egs/multi_zh_en/ASR/local/train_bbpe_model.py new file mode 120000 index 000000000..7fb4a9f9d --- /dev/null +++ b/egs/multi_zh_en/ASR/local/train_bbpe_model.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/train_bbpe_model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh new file mode 100755 index 000000000..9f2be5a5c --- /dev/null +++ b/egs/multi_zh_en/ASR/prepare.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +vocab_sizes=( + 2000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
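+#
+# Example invocation (a sketch; the stage numbers refer to the stages defined below
+# and assume the LibriSpeech/AiShell-2 preparation scripts have already been run):
+#   ./prepare.sh --stage 1 --stop-stage 4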
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +log "Dataset: musan" +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Soft link fbank of musan" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" + exit 1 + fi +fi + +log "Dataset: LibriSpeech" +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Soft link fbank of LibriSpeech" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +log "Dataset: AiShell-2" +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Soft link fbank of AiShell-2" + mkdir -p data/fbank + if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts*) . + ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats*) . + cd ../.. + else + log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare Byte BPE based lang" + mkdir -p data/fbank + if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then + log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi + + if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6" + exit 1 + fi + + cd data/ + if [ ! -d ./lang_char ]; then + ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) . + fi + if [ ! -d ./lang_bpe_500 ]; then + ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) . + fi + cd ../ + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + mkdir -p $lang_dir + + cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \ + > $lang_dir/text + + if [ ! -f $lang_dir/transcript_chars.txt ]; then + ./local/prepare_for_bpe_model.py \ + --lang-dir ./$lang_dir \ + --text $lang_dir/text + fi + + if [ ! -f $lang_dir/text_words_segmentation ]; then + python3 ./local/text2segments.py \ + --input-file ./data/lang_char/text \ + --output-file $lang_dir/text_words_segmentation + + cat ./data/lang_bpe_500/transcript_words.txt \ + >> $lang_dir/text_words_segmentation + + cat ./data/lang_char/text \ + >> $lang_dir/text + fi + + cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt + + if [ ! -f $lang_dir/words.txt ]; then + python3 ./local/prepare_words.py \ + --input-file $lang_dir/words_no_ids.txt \ + --output-file $lang_dir/words.txt + fi + + if [ ! 
-f $lang_dir/bbpe.model ]; then + ./local/train_bbpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bbpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bbpe.model + fi + done +fi + diff --git a/egs/multi_zh_en/ASR/shared b/egs/multi_zh_en/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/multi_zh_en/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..be6e94472 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,385 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
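+
+    A rough usage sketch (the argument names follow add_arguments() below; the
+    manifest filename is an assumption used only for illustration):
+
+        from lhotse import load_manifest_lazy
+
+        parser = argparse.ArgumentParser()
+        AsrDataModule.add_arguments(parser)
+        args = parser.parse_args(["--manifest-dir", "data/fbank"])
+
+        data_module = AsrDataModule(args)
+        cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
+        train_dl = data_module.train_dataloaders(cuts)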
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=300.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl diff --git a/egs/multi_zh_en/ASR/zipformer/beam_search.py b/egs/multi_zh_en/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py new file mode 100755 index 000000000..e21e8f052 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decode.py @@ -0,0 +1,851 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + + parser.add_argument( + "--use-librispeech", + type=str2bool, + default=False, + help="Whether to use LibriSpeech training data.", + ) + + parser.add_argument( + "--use-aishell2", + type=str2bool, + default=False, + help="Whether to use Aishell-2 training data.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
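+        # Concretely: 30 extra feature frames filled with LOG_EPS (the log of a tiny
+        # value, i.e. near-silence) are appended to every utterance below, so that a
+        # causal/streaming model has enough right context to finish its last chunk
+        # (roughly 0.3 s, assuming the usual 10 ms frame shift).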
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode( + byte_encode(tokenize_by_CJK_char(supervisions["text"])) + ), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += 
f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [tokenize_by_CJK_char(str(text)).split() for text in texts] + # print(texts) + # exit() + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
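+    # Setting return_cuts=True makes each batch carry batch["supervisions"]["cut"],
+    # from which decode_dataset() reads the cut IDs written to the recogs-*.txt files.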
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/decoder.py b/egs/multi_zh_en/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx.py b/egs/multi_zh_en/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export.py b/egs/multi_zh_en/ASR/zipformer/export.py new file mode 100755 index 000000000..fbd9ce0dd --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. 
+ +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +- non-streaming model: +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +import re +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bbpe_2000/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. 
For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, 
inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py new file mode 100755 index 000000000..68111fad7 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +./zipformer/generate_averaged_model.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(2) use the checkpoint exp_dir/checkpoint-iter.pt +./zipformer/generate_averaged_model.py \ + --iter 22000 \ + --avg 5 \ + --exp-dir ./zipformer/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + symbol_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = symbol_table[""] + params.unk_id = symbol_table[""] + params.vocab_size = len(symbol_table) + + print("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ 
b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/joiner.py b/egs/multi_zh_en/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/model.py b/egs/multi_zh_en/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py new file mode 100644 index 000000000..1155a3dcc --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py @@ -0,0 +1,247 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
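+
+# A rough sketch of how the corpora are combined (see train_cuts() below):
+# the selected CutSets are interleaved with lhotse's CutSet.mux, weighted by
+# their sizes, e.g.
+#   CutSet.mux(aishell_2_cuts, tal_csasr_cuts,
+#              weights=[len(aishell_2_cuts), len(tal_csasr_cuts)])
+# so cuts are drawn from each corpus roughly in proportion to its number of cuts.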
+ + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Dict + +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, args: argparse.Namespace): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aishell2_cuts_train.jsonl.gz + """ + self.fbank_dir = Path(args.manifest_dir) + self.use_tal_csasr = args.use_tal_csasr + self.use_librispeech = args.use_librispeech + self.use_aishell2 = args.use_aishell2 + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # AISHELL-2 + if self.use_aishell2: + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + # TAL-CSASR + if self.use_tal_csasr: + logging.info("Loading TAL-CSASR in lazy mode") + tal_csasr_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz" + ) + + # LibriSpeech + if self.use_librispeech: + logging.info("Loading LibriSpeech in lazy mode") + train_clean_100_cuts = self.train_clean_100_cuts() + train_clean_360_cuts = self.train_clean_360_cuts() + train_other_500_cuts = self.train_other_500_cuts() + + if self.use_tal_csasr and self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) + elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + ], + ) + elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(tal_csasr_cuts), + ], + ) + elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2: + return CutSet.mux( + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) + else: + raise NotImplementedError( + f"""Not implemented for + use_aishell2: {self.use_aishell2} + use_librispeech: {self.use_librispeech} + use_tal_csasr: {self.use_tal_csasr}""" + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + # AISHELL-2 + logging.info("Loading Aishell-2 DEV set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # LibriSpeech + dev_clean_cuts = self.dev_clean_cuts() + dev_other_cuts = self.dev_other_cuts() + + logging.info("Loading TAL-CSASR set in lazy mode") + tal_csasr_dev_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" + ) + + return CutSet.mux( + aishell2_dev_cuts, + dev_clean_cuts, + dev_other_cuts, + tal_csasr_dev_cuts, + weights=[ + len(aishell2_dev_cuts), + len(dev_clean_cuts), + len(dev_other_cuts), + len(tal_csasr_dev_cuts), + ], + ) + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # AISHELL-2 + if self.use_aishell2: + logging.info("Loading Aishell-2 set in lazy 
mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # LibriSpeech + if self.use_librispeech: + test_clean_cuts = self.test_clean_cuts() + test_other_cuts = self.test_other_cuts() + + logging.info("Loading TAL-CSASR set in lazy mode") + tal_csasr_test_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz" + ) + tal_csasr_dev_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" + ) + + test_cuts = { + "tal_csasr_test": tal_csasr_test_cuts, + "tal_csasr_dev": tal_csasr_dev_cuts, + } + + if self.use_aishell2: + test_cuts.update( + { + "aishell-2_test": aishell2_test_cuts, + "aishell-2_dev": aishell2_dev_cuts, + } + ) + if self.use_librispeech: + test_cuts.update( + { + "librispeech_test_clean": test_clean_cuts, + "librispeech_test_other": test_other_cuts, + } + ) + return test_cuts + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_check.py b/egs/multi_zh_en/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ 
b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/optim.py b/egs/multi_zh_en/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py new file mode 100755 index 000000000..676272e1f --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ 
+ /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to byte-level bpe model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + 
main() diff --git a/egs/multi_zh_en/ASR/zipformer/scaling.py b/egs/multi_zh_en/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py new file mode 120000 index 000000000..13fd02a78 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/subsampling.py b/egs/multi_zh_en/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py new file mode 100755 index 000000000..310c8fe59 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -0,0 +1,1416 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
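+
+# Which corpora are mixed in is controlled by --use-aishell2, --use-librispeech
+# and --use-tal-csasr (see get_parser() and multi_dataset.py; only some
+# combinations are supported, see MultiDataset.train_cuts()). For example,
+# adjust flags and paths to your setup:
+#
+#   ./zipformer/train.py --world-size 4 --use-fp16 1 \
+#     --use-aishell2 1 --use-librispeech 1 --use-tal-csasr 1 \
+#     --exp-dir zipformer/exp --max-duration 1000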
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from multi_dataset import MultiDataset +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
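+ # For example, with hypothetical settings --max-duration 1000, --world-size 4
+ # and --ref-duration 600, each real batch counts as 1000 * 4 / 600 ~= 6.7
+ # reference batches.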
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + + parser.add_argument( + "--use-librispeech", + type=str2bool, + default=False, + help="Whether to use LibriSpeech training data.", + ) + + parser.add_argument( + "--use-aishell2", + type=str2bool, + default=False, + help="Whether to use Aishell-2 training data.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + 
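+ # The model is assembled from a Conv2dSubsampling frontend (encoder_embed),
+ # a Zipformer2 encoder and, when --use-transducer is enabled, a stateless
+ # Decoder plus a Joiner; when --use-ctc is enabled, AsrModel additionally
+ # computes a CTC loss (see compute_loss()).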
encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
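+ # The per-frame loss that is logged is derived later by dividing the
+ # accumulated sum by the accumulated frame count, i.e.
+ # tot_loss["loss"] / tot_loss["frames"] (see train_one_epoch() and
+ # compute_validation_loss()).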
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. 
+ world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args) + + train_cuts = multi_dataset.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = data_module.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = multi_dataset.dev_cuts() + valid_dl = data_module.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/zipformer.py b/egs/multi_zh_en/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 0622dea30deacf2680dcca0549f7a05c0b965066 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 29 Nov 2023 21:28:38 +0800 Subject: [PATCH 106/113] Add a TTS recipe VITS on LJSpeech dataset (#1372) * first commit * replace phonimizer with g2p * use Conformer as text encoder * modify training script, clean codes * rename directory * convert text to tokens in data preparation stage * fix tts_datamodule.py * support onnx export and testing the exported onnx model * add doc * add README.md * fix style --- .flake8 | 2 +- docs/source/recipes/TTS/index.rst | 7 + docs/source/recipes/TTS/ljspeech/vits.rst | 113 +++ docs/source/recipes/index.rst | 3 +- .../TTS/local/compute_spectrogram_ljspeech.py | 106 ++ .../TTS/local/display_manifest_statistics.py | 73 ++ egs/ljspeech/TTS/local/prepare_token_file.py | 104 ++ .../TTS/local/prepare_tokens_ljspeech.py | 59 ++ egs/ljspeech/TTS/local/validate_manifest.py | 70 ++ egs/ljspeech/TTS/prepare.sh | 117 +++ egs/ljspeech/TTS/shared/parse_options.sh | 1 + egs/ljspeech/TTS/vits/README.md | 3 + egs/ljspeech/TTS/vits/duration_predictor.py | 194 ++++ egs/ljspeech/TTS/vits/export-onnx.py | 261 +++++ egs/ljspeech/TTS/vits/flow.py | 312 ++++++ egs/ljspeech/TTS/vits/generator.py | 531 ++++++++++ egs/ljspeech/TTS/vits/hifigan.py | 933 ++++++++++++++++++ egs/ljspeech/TTS/vits/infer.py | 233 +++++ egs/ljspeech/TTS/vits/loss.py | 336 +++++++ .../TTS/vits/monotonic_align/__init__.py | 81 ++ .../TTS/vits/monotonic_align/core.pyx | 51 + .../TTS/vits/monotonic_align/setup.py | 31 + egs/ljspeech/TTS/vits/posterior_encoder.py | 117 +++ egs/ljspeech/TTS/vits/residual_coupling.py | 229 +++++ egs/ljspeech/TTS/vits/test_onnx.py | 123 +++ egs/ljspeech/TTS/vits/text_encoder.py | 662 +++++++++++++ egs/ljspeech/TTS/vits/tokenizer.py | 106 ++ egs/ljspeech/TTS/vits/train.py | 893 +++++++++++++++++ egs/ljspeech/TTS/vits/transform.py | 218 ++++ egs/ljspeech/TTS/vits/tts_datamodule.py | 325 ++++++ egs/ljspeech/TTS/vits/utils.py | 265 +++++ 
egs/ljspeech/TTS/vits/vits.py | 610 ++++++++++++ egs/ljspeech/TTS/vits/wavenet.py | 349 +++++++ pyproject.toml | 1 + 34 files changed, 7517 insertions(+), 2 deletions(-) create mode 100644 docs/source/recipes/TTS/index.rst create mode 100644 docs/source/recipes/TTS/ljspeech/vits.rst create mode 100755 egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/display_manifest_statistics.py create mode 100755 egs/ljspeech/TTS/local/prepare_token_file.py create mode 100755 egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/validate_manifest.py create mode 100755 egs/ljspeech/TTS/prepare.sh create mode 120000 egs/ljspeech/TTS/shared/parse_options.sh create mode 100644 egs/ljspeech/TTS/vits/README.md create mode 100644 egs/ljspeech/TTS/vits/duration_predictor.py create mode 100755 egs/ljspeech/TTS/vits/export-onnx.py create mode 100644 egs/ljspeech/TTS/vits/flow.py create mode 100644 egs/ljspeech/TTS/vits/generator.py create mode 100644 egs/ljspeech/TTS/vits/hifigan.py create mode 100755 egs/ljspeech/TTS/vits/infer.py create mode 100644 egs/ljspeech/TTS/vits/loss.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/__init__.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/core.pyx create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/setup.py create mode 100644 egs/ljspeech/TTS/vits/posterior_encoder.py create mode 100644 egs/ljspeech/TTS/vits/residual_coupling.py create mode 100755 egs/ljspeech/TTS/vits/test_onnx.py create mode 100644 egs/ljspeech/TTS/vits/text_encoder.py create mode 100644 egs/ljspeech/TTS/vits/tokenizer.py create mode 100755 egs/ljspeech/TTS/vits/train.py create mode 100644 egs/ljspeech/TTS/vits/transform.py create mode 100644 egs/ljspeech/TTS/vits/tts_datamodule.py create mode 100644 egs/ljspeech/TTS/vits/utils.py create mode 100644 egs/ljspeech/TTS/vits/vits.py create mode 100644 egs/ljspeech/TTS/vits/wavenet.py diff --git a/.flake8 b/.flake8 index 410cb5482..cf276d0ba 100644 --- a/.flake8 +++ b/.flake8 @@ -15,7 +15,7 @@ per-file-ignores = egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203 egs/librispeech/ASR/zipformer/*.py: E501, E203 egs/librispeech/ASR/RESULTS.md: E999, - + egs/ljspeech/TTS/vits/*.py: E501, E203 # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst new file mode 100644 index 000000000..aa891c072 --- /dev/null +++ b/docs/source/recipes/TTS/index.rst @@ -0,0 +1,7 @@ +TTS +====== + +.. toctree:: + :maxdepth: 2 + + ljspeech/vits diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst new file mode 100644 index 000000000..385fd3c70 --- /dev/null +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -0,0 +1,113 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `LJSpeech `_ dataset. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/ljspeech/TTS + $ ./prepare.sh + +To run stage 1 to stage 5, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 5 + + +Build Monotonic Alignment Search +-------------------------------- + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. 
code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 7265e1cf6..8df61f0d0 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -2,7 +2,7 @@ Recipes ======= This page contains various recipes in ``icefall``. -Currently, only speech recognition recipes are provided. +Currently, we provide recipes for speech recognition, language model, and speech synthesis. We may add recipes for other tasks as well in the future. @@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future. Non-streaming-ASR/index Streaming-ASR/index RNN-LM/index + TTS/index diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py new file mode 100755 index 000000000..97c9008fc --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. 
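+It is normally invoked by ./prepare.sh (stage 2), but it can also be run
+directly, with no arguments, from egs/ljspeech/TTS:
+
+    ./local/compute_spectrogram_ljspeech.py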
+ +The generated spectrogram features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_ljspeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(4, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_ljspeech() diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..93f0044f0 --- /dev/null +++ b/egs/ljspeech/TTS/local/display_manifest_statistics.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. 
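+
+The script takes no arguments; run it from egs/ljspeech/TTS and it reads
+./data/spectrogram/ljspeech_cuts_all.jsonl.gz (see main() below).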
+ +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: + ╒═══════════════════════════╤══════════╕ + │ Cuts count: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Total duration (hh:mm:ss) │ 23:55:18 │ + ├───────────────────────────┼──────────┤ + │ mean │ 6.6 │ + ├───────────────────────────┼──────────┤ + │ std │ 2.2 │ + ├───────────────────────────┼──────────┤ + │ min │ 1.1 │ + ├───────────────────────────┼──────────┤ + │ 25% │ 5.0 │ + ├───────────────────────────┼──────────┤ + │ 50% │ 6.8 │ + ├───────────────────────────┼──────────┤ + │ 75% │ 8.4 │ + ├───────────────────────────┼──────────┤ + │ 99% │ 10.0 │ + ├───────────────────────────┼──────────┤ + │ 99.5% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ 99.9% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ max │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ Recordings available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Features available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Supervisions available: │ 13100 │ + ╘═══════════════════════════╧══════════╛ +""" diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py new file mode 100755 index 000000000..df976804a --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. 
+ "", # 2 for OOV + ] + all_tokens = set() + + cut_set = load_manifest(manifest_file) + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + for t in cut.tokens: + all_tokens.add(t) + + all_tokens = extra_tokens + list(all_tokens) + + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py new file mode 100755 index 000000000..fcd0137a0 --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest + + +def prepare_tokens_ljspeech(): + output_dir = Path("data/spectrogram") + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_ljspeech() diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py new file mode 100755 index 000000000..68159ae03 --- /dev/null +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh new file mode 100755 index 000000000..8ee40896e --- /dev/null +++ b/egs/ljspeech/TTS/prepare.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=1 +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # The directory $dl_dir/LJSpeech-1.1 will contain: + # - wavs, which contains the audio files + # - metadata.csv, which provides the transcript text for each audio clip + + # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink + # + # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1 + # + if [ ! -d $dl_dir/LJSpeech-1.1 ]; then + lhotse download ljspeech $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LJSpeech manifest" + # We assume that you have downloaded the LJSpeech corpus + # to $dl_dir/LJSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.ljspeech.done ]; then + lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests + touch data/manifests/.ljspeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for LJSpeech" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.ljspeech.done ]; then + ./local/compute_spectrogram_ljspeech.py + touch data/spectrogram/.ljspeech.done + fi + + if [ ! 
-e data/spectrogram/.ljspeech-validated.done ]; then + log "Validating data/spectrogram for LJSpeech" + python3 ./local/validate_manifest.py \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for LJSpeech" + if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py + mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" + if [ ! -e data/spectrogram/.ljspeech_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_test.jsonl.gz + + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_train.jsonl.gz + touch data/spectrogram/.ljspeech_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ + --tokens data/tokens.txt + fi +fi + + diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh new file mode 120000 index 000000000..e4665e7de --- /dev/null +++ b/egs/ljspeech/TTS/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md new file mode 100644 index 000000000..1141326b9 --- /dev/null +++ b/egs/ljspeech/TTS/vits/README.md @@ -0,0 +1,3 @@ +See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. + +Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29. diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py new file mode 100644 index 000000000..c29a28479 --- /dev/null +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -0,0 +1,194 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Stochastic duration predictor modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. 
+ +""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from flow import ( + ConvFlow, + DilatedDepthSeparableConv, + ElementwiseAffineFlow, + FlipFlow, + LogFlow, +) + + +class StochasticDurationPredictor(torch.nn.Module): + """Stochastic duration predictor module. + + This is a module of stochastic duration predictor described in `Conditional + Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + channels: int = 192, + kernel_size: int = 3, + dropout_rate: float = 0.5, + flows: int = 4, + dds_conv_layers: int = 3, + global_channels: int = -1, + ): + """Initialize StochasticDurationPredictor module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + dropout_rate (float): Dropout rate. + flows (int): Number of flows. + dds_conv_layers (int): Number of conv layers in DDS conv. + global_channels (int): Number of global conditioning channels. + + """ + super().__init__() + + self.pre = torch.nn.Conv1d(channels, channels, 1) + self.dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.proj = torch.nn.Conv1d(channels, channels, 1) + + self.log_flow = LogFlow() + self.flows = torch.nn.ModuleList() + self.flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.flows += [FlipFlow()] + + self.post_pre = torch.nn.Conv1d(1, channels, 1) + self.post_dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.post_proj = torch.nn.Conv1d(channels, channels, 1) + self.post_flows = torch.nn.ModuleList() + self.post_flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.post_flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.post_flows += [FlipFlow()] + + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + w: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + noise_scale: float = 1.0, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T_text). + x_mask (Tensor): Mask tensor (B, 1, T_text). + w (Optional[Tensor]): Duration tensor (B, 1, T_text). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1) + inverse (bool): Whether to inverse the flow. + noise_scale (float): Noise scale value. + + Returns: + Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,). + If inverse, log-duration tensor (B, 1, T_text). + + """ + x = x.detach() # stop gradient + x = self.pre(x) + if g is not None: + x = x + self.global_conv(g.detach()) # stop gradient + x = self.dds(x, x_mask) + x = self.proj(x) * x_mask + + if not inverse: + assert w is not None, "w must be provided." 
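+            # Training branch (a brief gloss, following the VITS paper): the
+            # integer durations w are dequantized variationally.  A Gaussian
+            # sample e_q is pushed through self.post_flows (conditioned on
+            # x + h_w) to give (z_u, z1); u = sigmoid(z_u) lies in (0, 1), so
+            # z0 = w - u is a continuous duration.  The method returns the
+            # negative log-likelihood of (z0, z1) under self.flows plus log q,
+            # i.e. a negative ELBO that serves as the duration loss.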
+ h_w = self.post_pre(w) + h_w = self.post_dds(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn( + w.size(0), + 2, + w.size(2), + ).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + logdet_tot_q = 0.0 + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in self.flows: + z, logdet = flow(z, x_mask, g=x, inverse=inverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # (B,) + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn( + x.size(0), + 2, + x.size(2), + ).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, inverse=inverse) + z0, z1 = z.split(1, 1) + logw = z0 + return logw diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py new file mode 100755 index 000000000..154de4bf4 --- /dev/null +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
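+        The checkpoint is read from {exp-dir}/epoch-{epoch}.pt, as saved
+        by ./vits/train.py.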
+ """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
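+
+    A rough sketch of driving the exported file with onnxruntime is shown
+    below; the token IDs are placeholders (real inputs come from
+    data/tokens.txt) and ./test_onnx.py is the reference implementation::
+
+        import numpy as np
+        import onnxruntime as ort
+
+        sess = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
+        candidates = {
+            "tokens": np.array([[10, 20, 30, 40]], dtype=np.int64),
+            "tokens_lens": np.array([4], dtype=np.int64),
+            "noise_scale": np.array([0.667], dtype=np.float32),
+            "noise_scale_dur": np.array([0.8], dtype=np.float32),
+            "alpha": np.array([1.0], dtype=np.float32),
+        }
+        # Feed only the inputs that actually survived tracing in the graph.
+        feed = {i.name: candidates[i.name] for i in sess.get_inputs()}
+        (audio,) = sess.run(["audio"], feed)  # (1, T'), float32 waveform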
+ """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py new file mode 100644 index 000000000..206bd5e3e --- /dev/null +++ b/egs/ljspeech/TTS/vits/flow.py @@ -0,0 +1,312 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Basic Flow modules used in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import math +from typing import Optional, Tuple, Union + +import torch + +from transform import piecewise_rational_quadratic_transform + + +class FlipFlow(torch.nn.Module): + """Flip flow module.""" + + def forward( + self, x: torch.Tensor, *args, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Flipped tensor (B, channels, T). 
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + x = torch.flip(x, [1]) + if not inverse: + logdet = x.new_zeros(x.size(0)) + return x, logdet + else: + return x + + +class LogFlow(torch.nn.Module): + """Log flow module.""" + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + inverse: bool = False, + eps: float = 1e-5, + **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + inverse (bool): Whether to inverse the flow. + eps (float): Epsilon for log. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = torch.log(torch.clamp_min(x, eps)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class ElementwiseAffineFlow(torch.nn.Module): + """Elementwise affine flow module.""" + + def __init__(self, channels: int): + """Initialize ElementwiseAffineFlow module. + + Args: + channels (int): Number of channels. + + """ + super().__init__() + self.channels = channels + self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1))) + self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1))) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_lengths (Tensor): Length tensor (B,). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class Transpose(torch.nn.Module): + """Transpose module for torch.nn.Sequential().""" + + def __init__(self, dim1: int, dim2: int): + """Initialize Transpose module.""" + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Transpose.""" + return x.transpose(self.dim1, self.dim2) + + +class DilatedDepthSeparableConv(torch.nn.Module): + """Dilated depth-separable conv module.""" + + def __init__( + self, + channels: int, + kernel_size: int, + layers: int, + dropout_rate: float = 0.0, + eps: float = 1e-5, + ): + """Initialize DilatedDepthSeparableConv module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + dropout_rate (float): Dropout rate. + eps (float): Epsilon for layer norm. 
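+
+        Note:
+            Layer i uses dilation kernel_size**i (see the loop below), so the
+            receptive field grows exponentially with the number of layers.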
+ + """ + super().__init__() + + self.convs = torch.nn.ModuleList() + for i in range(layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Conv1d( + channels, + channels, + 1, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Dropout(dropout_rate), + ) + ] + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, channels, T). + + """ + if g is not None: + x = x + g + for f in self.convs: + y = f(x * x_mask) + x = x + y + return x * x_mask + + +class ConvFlow(torch.nn.Module): + """Convolutional flow module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + layers: int, + bins: int = 10, + tail_bound: float = 5.0, + ): + """Initialize ConvFlow module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + bins (int): Number of bins. + tail_bound (float): Tail bound value. + + """ + super().__init__() + self.half_channels = in_channels // 2 + self.hidden_channels = hidden_channels + self.bins = bins + self.tail_bound = tail_bound + + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.dds_conv = DilatedDepthSeparableConv( + hidden_channels, + kernel_size, + layers, + dropout_rate=0.0, + ) + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * (bins * 3 - 1), + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
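+
+        Note:
+            This is a coupling layer: the first half of the channels (xa) is
+            passed through unchanged and is used, via the conv stack below, to
+            parameterize a piecewise rational-quadratic spline that transforms
+            the second half (xb), as in neural spline flows.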
+ + """ + xa, xb = x.split(x.size(1) // 2, 1) + h = self.input_conv(xa) + h = self.dds_conv(h, x_mask, g=g) + h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T) + + b, c, t = xa.shape + # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) + + # TODO(kan-bayashi): Understand this calculation + denom = math.sqrt(self.hidden_channels) + unnorm_widths = h[..., : self.bins] / denom + unnorm_heights = h[..., self.bins : 2 * self.bins] / denom + unnorm_derivatives = h[..., 2 * self.bins :] + xb, logdet_abs = piecewise_rational_quadratic_transform( + xb, + unnorm_widths, + unnorm_heights, + unnorm_derivatives, + inverse=inverse, + tails="linear", + tail_bound=self.tail_bound, + ) + x = torch.cat([xa, xb], 1) * x_mask + logdet = torch.sum(logdet_abs * x_mask, [1, 2]) + if not inverse: + return x, logdet + else: + return x diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py new file mode 100644 index 000000000..efb0e254c --- /dev/null +++ b/egs/ljspeech/TTS/vits/generator.py @@ -0,0 +1,531 @@ +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Generator module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + +from duration_predictor import StochasticDurationPredictor +from hifigan import HiFiGANGenerator +from posterior_encoder import PosteriorEncoder +from residual_coupling import ResidualAffineCouplingBlock +from text_encoder import TextEncoder +from utils import get_random_segments + + +class VITSGenerator(torch.nn.Module): + """Generator module in VITS, `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. 
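+
+    At a high level it wires together the modules defined alongside it: the
+    TextEncoder (prior), the PosteriorEncoder, a ResidualAffineCouplingBlock
+    flow, the StochasticDurationPredictor, and the HiFiGANGenerator decoder.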
+ """ + + def __init__( + self, + vocabs: int, + aux_channels: int = 513, + hidden_channels: int = 192, + spks: Optional[int] = None, + langs: Optional[int] = None, + spk_embed_dim: Optional[int] = None, + global_channels: int = -1, + segment_size: int = 32, + text_encoder_attention_heads: int = 2, + text_encoder_ffn_expand: int = 4, + text_encoder_cnn_module_kernel: int = 5, + text_encoder_blocks: int = 6, + text_encoder_dropout_rate: float = 0.1, + decoder_kernel_size: int = 7, + decoder_channels: int = 512, + decoder_upsample_scales: List[int] = [8, 8, 2, 2], + decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + decoder_resblock_kernel_sizes: List[int] = [3, 7, 11], + decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_weight_norm_in_decoder: bool = True, + posterior_encoder_kernel_size: int = 5, + posterior_encoder_layers: int = 16, + posterior_encoder_stacks: int = 1, + posterior_encoder_base_dilation: int = 1, + posterior_encoder_dropout_rate: float = 0.0, + use_weight_norm_in_posterior_encoder: bool = True, + flow_flows: int = 4, + flow_kernel_size: int = 5, + flow_base_dilation: int = 1, + flow_layers: int = 4, + flow_dropout_rate: float = 0.0, + use_weight_norm_in_flow: bool = True, + use_only_mean_in_flow: bool = True, + stochastic_duration_predictor_kernel_size: int = 3, + stochastic_duration_predictor_dropout_rate: float = 0.5, + stochastic_duration_predictor_flows: int = 4, + stochastic_duration_predictor_dds_conv_layers: int = 3, + ): + """Initialize VITS generator module. + + Args: + vocabs (int): Input vocabulary size. + aux_channels (int): Number of acoustic feature channels. + hidden_channels (int): Number of hidden channels. + spks (Optional[int]): Number of speakers. If set to > 1, assume that the + sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): Number of languages. If set to > 1, assume that the + lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, + assume that spembs will be provided as the input. + global_channels (int): Number of global conditioning channels. + segment_size (int): Segment size for decoder. + text_encoder_attention_heads (int): Number of heads in conformer block + of text encoder. + text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block + of text encoder. + text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder. + text_encoder_blocks (int): Number of conformer blocks in text encoder. + text_encoder_dropout_rate (float): Dropout rate in conformer block of + text encoder. + decoder_kernel_size (int): Decoder kernel size. + decoder_channels (int): Number of decoder initial channels. + decoder_upsample_scales (List[int]): List of upsampling scales in decoder. + decoder_upsample_kernel_sizes (List[int]): List of kernel size for + upsampling layers in decoder. + decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks + in decoder. + decoder_resblock_dilations (List[List[int]]): List of list of dilations for + resblocks in decoder. + use_weight_norm_in_decoder (bool): Whether to apply weight normalization in + decoder. + posterior_encoder_kernel_size (int): Posterior encoder kernel size. + posterior_encoder_layers (int): Number of layers of posterior encoder. + posterior_encoder_stacks (int): Number of stacks of posterior encoder. + posterior_encoder_base_dilation (int): Base dilation of posterior encoder. 
+ posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. + use_weight_norm_in_posterior_encoder (bool): Whether to apply weight + normalization in posterior encoder. + flow_flows (int): Number of flows in flow. + flow_kernel_size (int): Kernel size in flow. + flow_base_dilation (int): Base dilation in flow. + flow_layers (int): Number of layers in flow. + flow_dropout_rate (float): Dropout rate in flow + use_weight_norm_in_flow (bool): Whether to apply weight normalization in + flow. + use_only_mean_in_flow (bool): Whether to use only mean in flow. + stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic + duration predictor. + stochastic_duration_predictor_dropout_rate (float): Dropout rate in + stochastic duration predictor. + stochastic_duration_predictor_flows (int): Number of flows in stochastic + duration predictor. + stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv + layers in stochastic duration predictor. + + """ + super().__init__() + self.segment_size = segment_size + self.text_encoder = TextEncoder( + vocabs=vocabs, + d_model=hidden_channels, + num_heads=text_encoder_attention_heads, + dim_feedforward=hidden_channels * text_encoder_ffn_expand, + cnn_module_kernel=text_encoder_cnn_module_kernel, + num_layers=text_encoder_blocks, + dropout=text_encoder_dropout_rate, + ) + self.decoder = HiFiGANGenerator( + in_channels=hidden_channels, + out_channels=1, + channels=decoder_channels, + global_channels=global_channels, + kernel_size=decoder_kernel_size, + upsample_scales=decoder_upsample_scales, + upsample_kernel_sizes=decoder_upsample_kernel_sizes, + resblock_kernel_sizes=decoder_resblock_kernel_sizes, + resblock_dilations=decoder_resblock_dilations, + use_weight_norm=use_weight_norm_in_decoder, + ) + self.posterior_encoder = PosteriorEncoder( + in_channels=aux_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + kernel_size=posterior_encoder_kernel_size, + layers=posterior_encoder_layers, + stacks=posterior_encoder_stacks, + base_dilation=posterior_encoder_base_dilation, + global_channels=global_channels, + dropout_rate=posterior_encoder_dropout_rate, + use_weight_norm=use_weight_norm_in_posterior_encoder, + ) + self.flow = ResidualAffineCouplingBlock( + in_channels=hidden_channels, + hidden_channels=hidden_channels, + flows=flow_flows, + kernel_size=flow_kernel_size, + base_dilation=flow_base_dilation, + layers=flow_layers, + global_channels=global_channels, + dropout_rate=flow_dropout_rate, + use_weight_norm=use_weight_norm_in_flow, + use_only_mean=use_only_mean_in_flow, + ) + # TODO(kan-bayashi): Add deterministic version as an option + self.duration_predictor = StochasticDurationPredictor( + channels=hidden_channels, + kernel_size=stochastic_duration_predictor_kernel_size, + dropout_rate=stochastic_duration_predictor_dropout_rate, + flows=stochastic_duration_predictor_flows, + dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, + global_channels=global_channels, + ) + + self.upsample_factor = int(np.prod(decoder_upsample_scales)) + self.spks = None + if spks is not None and spks > 1: + assert global_channels > 0 + self.spks = spks + self.global_emb = torch.nn.Embedding(spks, global_channels) + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + assert global_channels > 0 + self.spk_embed_dim = spk_embed_dim + self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels) + self.langs = None + if langs is not None and langs > 1: + assert 
global_channels > 0 + self.langs = langs + self.lang_emb = torch.nn.Embedding(langs, global_channels) + + # delayed import + from monotonic_align import maximum_path + + self.maximum_path = maximum_path + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + ]: + """Calculate forward propagation. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). + Tensor: Duration negative log-likelihood (NLL) tensor (B,). + Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). + Tensor: Segments start index tensor (B,). + Tensor: Text mask tensor (B, 1, T_text). + Tensor: Feature mask tensor (B, 1, T_feats). + tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + - Tensor: Posterior encoder hidden representation (B, H, T_feats). + - Tensor: Flow hidden representation (B, H, T_feats). + - Tensor: Expanded text encoder projected mean (B, H, T_feats). + - Tensor: Expanded text encoder projected scale (B, H, T_feats). + - Tensor: Posterior encoder projected mean (B, H, T_feats). + - Tensor: Posterior encoder projected scale (B, H, T_feats). 
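+
+        Note:
+            The monotonic attention weight is obtained with a maximum-path search
+            over the prior likelihood of every (feature frame, text token) pair,
+            and only randomly cropped latent segments of ``segment_size`` frames
+            are vocoded during training, so the returned waveform covers
+            ``segment_size * upsample_factor`` samples rather than the full
+            utterance.  ``z_start_idxs`` records where each crop begins so the
+            matching ground-truth audio can be sliced for the GAN and
+            mel-spectrogram losses.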
+ + """ + # forward text encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + + # calculate global conditioning + g = None + if self.spks is not None: + # speaker one-hot vector embedding: (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # language one-hot vector embedding: (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = ( + self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ) + .unsqueeze(1) + .detach() + ) + + # forward duration predictor + w = attn.sum(2) # (B, 1, T_text) + dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) + dur_nll = dur_nll / torch.sum(x_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + # get random segments + z_segments, z_start_idxs = get_random_segments( + z, + feats_lengths, + self.segment_size, + ) + + # forward decoder with random segments + wav = self.decoder(z_segments, g=g) + + return ( + wav, + dur_nll, + attn, + z_start_idxs, + x_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) + + def inference( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: Optional[torch.Tensor] = None, + feats_lengths: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + dur: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text,). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats,). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). 
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided, + skip the prediction of durations (i.e., teacher forcing). + noise_scale (float): Noise scale parameter for flow. + noise_scale_dur (float): Noise scale parameter for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length of acoustic feature sequence. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Tensor: Generated waveform tensor (B, T_wav). + Tensor: Monotonic attention weight tensor (B, T_feats, T_text). + Tensor: Duration tensor (B, T_text). + + """ + # encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + x_mask = x_mask.to(x.dtype) + g = None + if self.spks is not None: + # (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + if use_teacher_forcing: + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ).unsqueeze(1) + dur = attn.sum(2) # (B, 1, T_text) + + # forward decoder with random segments + wav = self.decoder(z * y_mask, g=g) + else: + # duration + if dur is None: + logw = self.duration_predictor( + x, + x_mask, + g=g, + inverse=True, + noise_scale=noise_scale_dur, + ) + w = torch.exp(logw) * x_mask * alpha + dur = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() + y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) + y_mask = y_mask.to(x.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = self._generate_path(dur, attn_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul( + attn.squeeze(1), + m_p.transpose(1, 2), + ).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul( + attn.squeeze(1), + logs_p.transpose(1, 2), + ).transpose(1, 2) + + # decoder + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, inverse=True) + 
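+            # Unlike training, no posterior encoder is used here: z_p is sampled
+            # from the expanded text-encoder prior (scaled by noise_scale) and
+            # mapped through the inverse flow before vocoding.  If max_len is
+            # given, (z * y_mask)[:, :, :max_len] simply truncates the latent
+            # sequence before the decoder.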
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) + + return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1) + + def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate path a.k.a. monotonic attention. + + Args: + dur (Tensor): Duration tensor (B, 1, T_text). + mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text). + + Returns: + Tensor: Path tensor (B, 1, T_feats, T_text). + + """ + b, _, t_y, t_x = mask.shape + cum_dur = torch.cumsum(dur, -1) + cum_dur_flat = cum_dur.view(b * t_x) + path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) + path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) + # path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + path = path.view(b, t_x, t_y).to(dtype=torch.float) + # path will be like (t_x = 3, t_y = 5): + # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], + # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], + # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] + path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + # path = path.to(dtype=mask.dtype) + return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py new file mode 100644 index 000000000..589ac30f6 --- /dev/null +++ b/egs/ljspeech/TTS/vits/hifigan.py @@ -0,0 +1,933 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFi-GAN Modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +import copy +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + + +class HiFiGANGenerator(torch.nn.Module): + """HiFiGAN generator module.""" + + def __init__( + self, + in_channels: int = 80, + out_channels: int = 1, + channels: int = 512, + global_channels: int = -1, + kernel_size: int = 7, + upsample_scales: List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_additional_convs: bool = True, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + ): + """Initialize HiFiGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + global_channels (int): Number of global conditioning channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (List[int]): List of upsampling scales. + upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. + resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. + resblock_dilations (List[List[int]]): List of list of dilations for residual + blocks. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. 
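+
+        Note:
+            The total upsampling factor is the product of ``upsample_scales``;
+            with the defaults ``[8, 8, 2, 2]`` each input frame becomes 256
+            output samples.  A minimal shape sketch (values are illustrative):
+
+                generator = HiFiGANGenerator()     # defaults shown above
+                mel = torch.randn(1, 80, 100)      # (B, in_channels, T)
+                wav = generator(mel)               # (1, 1, 25600), i.e. T * 256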
+ + """ + super().__init__() + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernel size must be odd number." + assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.upsample_factor = int(np.prod(upsample_scales) * out_channels) + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2**i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + c = self.input_conv(c) + if g is not None: + c = c + self.global_conv(g) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. 
+ https://github.com/jik876/hifi-gan/blob/master/models.py + + """ + + def _reset_parameters(m: torch.nn.Module): + if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + m.weight.data.normal_(0.0, 0.01) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def inference( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Perform inference. + + Args: + c (torch.Tensor): Input tensor (T, in_channels). + g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). + + Returns: + Tensor: Output tensor (T ** upsample_factor, out_channels). + + """ + if g is not None: + g = g.unsqueeze(0) + c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g) + return c.squeeze(0).transpose(1, 0) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in HiFiGAN.""" + + def __init__( + self, + kernel_size: int = 3, + channels: int = 512, + dilations: List[int] = [1, 3, 5], + bias: bool = True, + use_additional_convs: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels for convolution layer. + dilations (List[int]): List of dilation factors. + use_additional_convs (bool): Whether to use additional convolution layers. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + + """ + super().__init__() + self.use_additional_convs = use_additional_convs + self.convs1 = torch.nn.ModuleList() + if use_additional_convs: + self.convs2 = torch.nn.ModuleList() + assert kernel_size % 2 == 1, "Kernel size must be odd number." + for dilation in dilations: + self.convs1 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + bias=bias, + padding=(kernel_size - 1) // 2 * dilation, + ), + ) + ] + if use_additional_convs: + self.convs2 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + ) + ] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, channels, T). 
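+
+        Note:
+            Each iteration below applies one dilated convolution (plus an extra
+            non-dilated convolution when ``use_additional_convs=True``) and adds
+            the result back onto ``x``, so with the default ``dilations=[1, 3, 5]``
+            the output accumulates three residual refinements with growing
+            receptive field.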
+ + """ + for idx in range(len(self.convs1)): + xt = self.convs1[idx](x) + if self.use_additional_convs: + xt = self.convs2[idx](xt) + x = xt + x + return x + + +class HiFiGANPeriodDiscriminator(torch.nn.Module): + """HiFiGAN period discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + period: int = 3, + kernel_sizes: List[int] = [5, 3], + channels: int = 32, + downsample_scales: List[int] = [3, 3, 3, 3, 1], + max_downsample_channels: int = 1024, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initialize HiFiGANPeriodDiscriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + period (int): Period. + kernel_sizes (list): Kernel sizes of initial conv layers and the final conv + layer. + channels (int): Number of initial channels. + downsample_scales (List[int]): List of downsampling scales. + max_downsample_channels (int): Number of maximum downsampling channels. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. + If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number." + assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number." + + self.period = period + self.convs = torch.nn.ModuleList() + in_chs = in_channels + out_chs = channels + for downsample_scale in downsample_scales: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_chs, + out_chs, + (kernel_sizes[0], 1), + (downsample_scale, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Use downsample_scale + 1? + out_chs = min(out_chs * 4, max_downsample_channels) + self.output_conv = torch.nn.Conv2d( + out_chs, + out_channels, + (kernel_sizes[1] - 1, 1), + 1, + padding=((kernel_sizes[1] - 1) // 2, 0), + ) + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + list: List of each layer's tensors. 
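+
+        Note:
+            The 1D input is reflect-padded so that T becomes a multiple of the
+            period P and is then viewed as a 2D map of shape (B, C, T/P, P), so
+            the 2D convolutions below only mix samples that are P steps apart.
+            For example, with P = 3 and T = 10, two padded samples give T = 12
+            and a (B, C, 4, 3) input to the first conv layer.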
+ + """ + # transform 1d to 2d -> (B, C, T/P, P) + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t += n_pad + x = x.view(b, c, t // self.period, self.period) + + # forward conv + outs = [] + for layer in self.convs: + x = layer(x) + outs += [x] + x = self.output_conv(x) + x = torch.flatten(x, 1, -1) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + +class HiFiGANMultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN multi-period discriminator module.""" + + def __init__( + self, + periods: List[int] = [2, 3, 5, 7, 11], + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initialize HiFiGANMultiPeriodDiscriminator module. + + Args: + periods (List[int]): List of periods. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + for period in periods: + params = copy.deepcopy(discriminator_params) + params["period"] = period + self.discriminators += [HiFiGANPeriodDiscriminator(**params)] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each + layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + + return outs + + +class HiFiGANScaleDiscriminator(torch.nn.Module): + """HiFi-GAN scale discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_sizes: List[int] = [15, 41, 5, 3], + channels: int = 128, + max_downsample_channels: int = 1024, + max_groups: int = 16, + bias: int = True, + downsample_scales: List[int] = [2, 2, 4, 4, 1], + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initilize HiFiGAN scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (List[int]): List of four kernel sizes. The first will be used + for the first conv layer, and the second is for downsampling part, and + the remaining two are for the last two output layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling + layers. 
+ bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (List[int]): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. If set to true, it + will be applied to all of the conv layers. + + """ + super().__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 4 + for ks in kernel_sizes: + assert ks % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + channels, + # NOTE(kan-bayashi): Use always the same kernel size + kernel_sizes[0], + bias=bias, + padding=(kernel_sizes[0] - 1) // 2, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + out_chs = channels + # NOTE(kan-bayashi): Remove hard coding? + groups = 4 + for downsample_scale in downsample_scales: + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[1], + stride=downsample_scale, + padding=(kernel_sizes[1] - 1) // 2, + groups=groups, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Remove hard coding? + out_chs = min(in_chs * 2, max_downsample_channels) + # NOTE(kan-bayashi): Remove hard coding? + groups = min(groups * 4, max_groups) + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_size=kernel_sizes[3], + stride=1, + padding=(kernel_sizes[3] - 1) // 2, + bias=bias, + ), + ] + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + self.use_weight_norm = use_weight_norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + self.use_spectral_norm = use_spectral_norm + if use_spectral_norm: + self.apply_spectral_norm() + + # backward compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[Tensor]: List of output tensors of each layer. 
+ + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def remove_spectral_norm(self): + """Remove spectral normalization module from all of the layers.""" + + def _remove_spectral_norm(m): + try: + logging.debug(f"Spectral norm is removed from {m}.") + torch.nn.utils.remove_spectral_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_spectral_norm) + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Fix the compatibility of weight / spectral normalization issue. + + Some pretrained models are trained with configs that use weight / spectral + normalization, but actually, the norm is not applied. This causes the mismatch + of the parameters with configs. To solve this issue, when parameter mismatch + happens in loading pretrained model, we remove the norm from the current model. + + See also: + - https://github.com/espnet/espnet/pull/5240 + - https://github.com/espnet/espnet/pull/5249 + - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409 + + """ + current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)] + if self.use_weight_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems weight norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_weight_norm() + self.use_weight_norm = False + for k in current_module_keys: + if k.endswith("weight_g") or k.endswith("weight_v"): + del state_dict[k] + + if self.use_spectral_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems spectral norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. 
To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_spectral_norm() + self.use_spectral_norm = False + for k in current_module_keys: + if ( + k.endswith("weight_u") + or k.endswith("weight_v") + or k.endswith("weight_orig") + ): + del state_dict[k] + + +class HiFiGANMultiScaleDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale discriminator module.""" + + def __init__( + self, + scales: int = 3, + downsample_pooling: str = "AvgPool1d", + # follow the official implementation setting + downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = False, + ): + """Initilize HiFiGAN multi-scale discriminator module. + + Args: + scales (int): Number of multi-scales. + downsample_pooling (str): Pooling module name for downsampling of the + inputs. + downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling + module. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm + and the other discriminators use weight norm. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for i in range(scales): + params = copy.deepcopy(discriminator_params) + if follow_official_norm: + if i == 0: + params["use_weight_norm"] = False + params["use_spectral_norm"] = True + else: + params["use_weight_norm"] = True + params["use_spectral_norm"] = False + self.discriminators += [HiFiGANScaleDiscriminator(**params)] + self.pooling = None + if scales > 1: + self.pooling = getattr(torch.nn, downsample_pooling)( + **downsample_pooling_params + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[torch.Tensor]]: List of list of each discriminator outputs, + which consists of eachlayer output tensors. 
+ + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + if self.pooling is not None: + x = self.pooling(x) + + return outs + + +class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale + multi-period discriminator module.""" + + def __init__( + self, + # Multi-scale discriminator related + scales: int = 3, + scale_downsample_pooling: str = "AvgPool1d", + scale_downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + scale_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = True, + # Multi-period discriminator related + periods: List[int] = [2, 3, 5, 7, 11], + period_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initilize HiFiGAN multi-scale + multi-period discriminator module. + + Args: + scales (int): Number of multi-scales. + scale_downsample_pooling (str): Pooling module name for downsampling of the + inputs. + scale_downsample_pooling_params (dict): Parameters for the above pooling + module. + scale_discriminator_params (dict): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm and + the other discriminators use weight norm. + periods (list): List of periods. + period_discriminator_params (dict): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.msd = HiFiGANMultiScaleDiscriminator( + scales=scales, + downsample_pooling=scale_downsample_pooling, + downsample_pooling_params=scale_downsample_pooling_params, + discriminator_params=scale_discriminator_params, + follow_official_norm=follow_official_norm, + ) + self.mpd = HiFiGANMultiPeriodDiscriminator( + periods=periods, + discriminator_params=period_discriminator_params, + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[Tensor]]: List of list of each discriminator outputs, + which consists of each layer output tensors. Multi scale and + multi period ones are concatenated. + + """ + msd_outs = self.msd(x) + mpd_outs = self.mpd(x) + return msd_outs + mpd_outs diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py new file mode 100755 index 000000000..91a35e360 --- /dev/null +++ b/egs/ljspeech/TTS/vits/infer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import k2 +import torch +import torch.nn as nn +import torchaudio + +from train import get_model, get_params +from tokenizer import Tokenizer + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger +from tts_datamodule import LJSpeechTtsDataModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + # Background worker save audios to disk. + def _save_worker( + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), + audio[i:i + 1, :audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), + audio_pred[i:i + 1, :audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
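+
+    # The loop below runs synthesis on the accelerator while a single background
+    # worker thread writes the ground-truth / generated wav pairs to disk, so
+    # file I/O does not stall the next batch; the collected futures are joined
+    # at the end, which also surfaces any exception raised in _save_worker.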
+ + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + + futures.append( + executor.submit( + _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + infer_dataset( + dl=test_dl, + params=params, + model=model, + tokenizer=tokenizer, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py new file mode 100644 index 000000000..21aaad6e7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/loss.py @@ -0,0 +1,336 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFiGAN-related loss modules. 
+ +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +from typing import List, Tuple, Union + +import torch +import torch.distributions as D +import torch.nn.functional as F + +from lhotse.features.kaldi import Wav2LogFilterBank + + +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize GeneratorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward( + self, + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Calcualate generator adversarial loss. + + Args: + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs.. + + Returns: + Tensor: Generator adversarial loss value. + + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize DiscriminatorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward( + self, + outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calcualate discriminator adversarial loss. + + Args: + outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from generator. + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value. 
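+
+        Note:
+            With the default ``loss_type="mse"`` (LSGAN-style), real outputs are
+            pulled towards 1 and fake outputs towards 0.  A minimal sketch with
+            dummy score maps (shapes are illustrative only):
+
+                d_real = torch.randn(4, 1, 100)   # D(x)
+                d_fake = torch.randn(4, 1, 100)   # D(G(z))
+                real_loss = F.mse_loss(d_real, d_real.new_ones(d_real.size()))
+                fake_loss = F.mse_loss(d_fake, d_fake.new_zeros(d_fake.size()))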
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = False, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from groundtruth.. + + Returns: + Tensor: Feature matching loss value. 
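+
+        Note:
+            The loss below is an L1 distance between intermediate discriminator
+            feature maps computed on generated and real audio, with the
+            real-audio features detached so gradients only flow into the
+            generator; final discriminator outputs are excluded unless
+            ``include_final_outputs=True``.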
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramLoss(torch.nn.Module): + """Mel-spectrogram loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + frame_length: int = 1024, # in samples + frame_shift: int = 256, # in samples + n_mels: int = 80, + use_fft_mag: bool = True, + ): + super().__init__() + self.wav_to_mel = Wav2LogFilterBank( + sampling_rate=sampling_rate, + frame_length=frame_length / sampling_rate, # in second + frame_shift=frame_shift / sampling_rate, # in second + use_fft_mag=use_fft_mag, + num_filters=n_mels, + ) + + def forward( + self, + y_hat: torch.Tensor, + y: torch.Tensor, + return_mel: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Calculate Mel-spectrogram loss. + + Args: + y_hat (Tensor): Generated waveform tensor (B, 1, T). + y (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_hat = self.wav_to_mel(y_hat.squeeze(1)) + mel = self.wav_to_mel(y.squeeze(1)) + mel_loss = F.l1_loss(mel_hat, mel) + + if return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py + +"""VITS-related loss modules. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +class KLDivergenceLoss(torch.nn.Module): + """KL divergence loss.""" + + def forward( + self, + z_p: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + z_mask: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss. + + Args: + z_p (Tensor): Flow hidden representation (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). + z_mask (Tensor): Mask tensor (B, 1, T_feats). + + Returns: + Tensor: KL divergence loss. + + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + loss = kl / torch.sum(z_mask) + + return loss + + +class KLDivergenceLossWithoutFlow(torch.nn.Module): + """KL divergence loss without flow.""" + + def forward( + self, + m_q: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss without flow. + + Args: + m_q (Tensor): Posterior encoder projected mean (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). 
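+
+        Returns:
+            Tensor: KL divergence between the posterior and the prior
+            Gaussians, averaged over all elements.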
+ """ + posterior_norm = D.Normal(m_q, torch.exp(logs_q)) + prior_norm = D.Normal(m_p, torch.exp(logs_p)) + loss = D.kl_divergence(posterior_norm, prior_norm).mean() + return loss diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py new file mode 100644 index 000000000..2b35654f5 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py @@ -0,0 +1,81 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py + +"""Maximum path calculation module. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import warnings + +import numpy as np +import torch +from numba import njit, prange + +try: + from .core import maximum_path_c + + is_cython_avalable = True +except ImportError: + is_cython_avalable = False + warnings.warn( + "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. " + "If you want to use the cython version, please build it as follows: " + "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`" + ) + + +def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """Calculate maximum path. + + Args: + neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). + attn_mask (Tensor): Attention mask (B, T_feats, T_text). + + Returns: + Tensor: Maximum path tensor (B, T_feats, T_text). + + """ + device, dtype = neg_x_ent.device, neg_x_ent.dtype + neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32) + path = np.zeros(neg_x_ent.shape, dtype=np.int32) + t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) + t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) + if is_cython_avalable: + maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) + else: + maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) + + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +@njit +def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): + """Calculate a single maximum path with numba.""" + index = t_x - 1 + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@njit(parallel=True) +def maximum_path_numba(paths, values, t_ys, t_xs): + """Calculate batch maximum path with numba.""" + for i in prange(paths.shape[0]): + maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx new file mode 100644 index 000000000..c02c2d02e --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx @@ -0,0 +1,51 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx + +"""Maximum path calculation module with cython optimization. + +This code is copied from https://github.com/jaywalnut310/vits and modifed code format. 
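+
+Build the extension in place with ``python setup.py build_ext --inplace``
+(see the accompanying setup.py); otherwise the numba fallback in
+__init__.py is used.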
+ +""" + +cimport cython + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py new file mode 100644 index 000000000..33d75e176 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py @@ -0,0 +1,31 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py +"""Setup cython code.""" + +from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] +setup( + name="monotonic_align", + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, +) diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py new file mode 100644 index 000000000..6b8a5be52 --- /dev/null +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -0,0 +1,117 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Posterior encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple + +import torch + +from icefall.utils import make_pad_mask +from wavenet import WaveNet, Conv1d + + +class PosteriorEncoder(torch.nn.Module): + """Posterior encoder module in VITS. + + This is a module of posterior encoder described in `Conditional Variational + Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + in_channels: int = 513, + out_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + layers: int = 16, + stacks: int = 1, + base_dilation: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + ): + """Initilialize PosteriorEncoder module. 
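+
+        The module maps linear-spectrogram frames to the latent z: a 1x1 input
+        convolution, a non-causal WaveNet, and a 1x1 projection that yields the
+        mean and log-scale from which z is sampled.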
+ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size in WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of repeat stacking of WaveNet. + base_dilation (int): Base dilation factor. + global_channels (int): Number of global conditioning channels. + dropout_rate (float): Dropout rate. + bias (bool): Whether to use bias parameters in conv. + use_weight_norm (bool): Whether to apply weight norm. + + """ + super().__init__() + + # define modules + self.input_conv = Conv1d(in_channels, hidden_channels, 1) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + self.proj = Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_feats). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). + Tensor: Projected mean tensor (B, out_channels, T_feats). + Tensor: Projected scale tensor (B, out_channels, T_feats). + Tensor: Mask tensor for input tensor (B, 1, T_feats). + + """ + x_mask = ( + (~make_pad_mask(x_lengths)) + .unsqueeze(1) + .to( + dtype=x.dtype, + device=x.device, + ) + ) + x = self.input_conv(x) * x_mask + x = self.encoder(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + + return z, m, logs, x_mask diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py new file mode 100644 index 000000000..2d6807cb7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -0,0 +1,229 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Residual affine coupling modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple, Union + +import torch + +from flow import FlipFlow +from wavenet import WaveNet + + +class ResidualAffineCouplingBlock(torch.nn.Module): + """Residual affine coupling block module. + + This is a module of residual affine coupling block, which used as "Flow" in + `Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`_. + + .. 
_`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + flows: int = 4, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 4, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initilize ResidualAffineCouplingBlock module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + flows (int): Number of flows. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. + + """ + super().__init__() + + self.flows = torch.nn.ModuleList() + for i in range(flows): + self.flows += [ + ResidualAffineCouplingLayer( + in_channels=in_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + base_dilation=base_dilation, + layers=layers, + stacks=1, + global_channels=global_channels, + dropout_rate=dropout_rate, + use_weight_norm=use_weight_norm, + bias=bias, + use_only_mean=use_only_mean, + ) + ] + self.flows += [FlipFlow()] + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + + """ + if not inverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, inverse=inverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, inverse=inverse) + return x + + +class ResidualAffineCouplingLayer(torch.nn.Module): + """Residual affine coupling layer.""" + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 5, + stacks: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initialzie ResidualAffineCouplingLayer module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. 
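+
+        Note:
+            The input is split channel-wise into two halves; the first half
+            conditions a WaveNet that predicts the affine parameters applied
+            to the second half, which keeps the transform easily invertible.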
+ + """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + # define modules + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + if use_only_mean: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels, + 1, + ) + else: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * 2, + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + xa, xb = x.split(x.size(1) // 2, dim=1) + h = self.input_conv(xa) * x_mask + h = self.encoder(h, x_mask, g=g) + stats = self.proj(h) * x_mask + if not self.use_only_mean: + m, logs = stats.split(stats.size(1) // 2, dim=1) + else: + m = stats + logs = torch.zeros_like(m) + + if not inverse: + xb = m + xb * torch.exp(logs) * x_mask + x = torch.cat([xa, xb], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + xb = (xb - m) * torch.exp(-logs) * x_mask + x = torch.cat([xa, xb], 1) + return x diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py new file mode 100755 index 000000000..8acca7c02 --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +import onnxruntime as ort +import torch +import torchaudio + +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." + tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + audio = model(tokens, tokens_lens) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py new file mode 100644 index 000000000..9f337e45b --- /dev/null +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -0,0 +1,662 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text encoder module in VITS. + +This code is based on + - https://github.com/jaywalnut310/vits + - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py + - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py +""" + +import copy +import math +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.utils import is_jit_tracing, make_pad_mask + + +class TextEncoder(torch.nn.Module): + """Text encoder module in VITS. + + This is a module of text encoder described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. + """ + + def __init__( + self, + vocabs: int, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + cnn_module_kernel: int = 5, + num_layers: int = 6, + dropout: float = 0.1, + ): + """Initialize TextEncoder module. + + Args: + vocabs (int): Vocabulary size. + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + super().__init__() + self.d_model = d_model + + # define modules + self.emb = torch.nn.Embedding(vocabs, d_model) + torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + + # We use conformer as text encoder + self.encoder = Transformer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, + num_layers=num_layers, + dropout=dropout, + ) + + self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input index tensor (B, T_text). + x_lengths (Tensor): Length tensor (B,). + + Returns: + Tensor: Encoded hidden representation (B, attention_dim, T_text). + Tensor: Projected mean tensor (B, attention_dim, T_text). + Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Mask tensor for input tensor (B, 1, T_text). 
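+
+        Note:
+            The projected mean and scale parameterize the text-conditioned
+            prior over the latent representation used by the VITS generator.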
+ + """ + # (B, T_text, embed_dim) + x = self.emb(x) * math.sqrt(self.d_model) + + assert x.size(1) == x_lengths.max().item() + + # (B, T_text) + pad_mask = make_pad_mask(x_lengths) + + # encoder assume the channel last (B, T_text, embed_dim) + x = self.encoder(x, key_padding_mask=pad_mask) + + # convert the channel first (B, embed_dim, T_text) + x = x.transpose(1, 2) + non_pad_mask = (~pad_mask).unsqueeze(1) + stats = self.proj(x) * non_pad_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + + return x, m, logs, non_pad_mask + + +class Transformer(nn.Module): + """ + Args: + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + cnn_module_kernel: int = 5, + num_layers: int = 6, + dropout: float = 0.1, + ) -> None: + super().__init__() + + self.num_layers = num_layers + self.d_model = d_model + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, + dropout=dropout, + ) + self.encoder = TransformerEncoder(encoder_layer, num_layers) + self.after_norm = nn.LayerNorm(d_model) + + def forward( + self, x: Tensor, key_padding_mask: Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + lengths: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + """ + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x = self.encoder( + x, pos_emb, key_padding_mask=key_padding_mask + ) # (T, N, C) + + x = self.after_norm(x) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x + + +class TransformerEncoderLayer(nn.Module): + """ + TransformerEncoderLayer is made up of self-attn and feedforward. + + Args: + d_model: the number of expected features in the input. + num_heads: the number of heads in the multi-head attention models. + dim_feedforward: the dimension of the feed-forward network model. + dropout: the dropout value (default=0.1). 
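+        cnn_module_kernel: the kernel size of the depthwise convolution in the
+            convolution module.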
+ """ + + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int, + cnn_module_kernel: int, + dropout: float = 0.1, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + + self.ff_scale = 0.5 + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the transformer encoder layer. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + # macaron style feed-forward module + src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + + # multi-head self-attention module + src_attn = self.self_attn( + self.norm_mha(src), + pos_emb=pos_emb, + key_padding_mask=key_padding_mask, + ) + src = src + self.dropout(src_attn) + + # convolution module + src = src + self.dropout(self.conv_module(self.norm_conv(src))) + + # feed-forward module + src = src + self.dropout(self.feed_forward(self.norm_ff(src))) + + src = self.norm_final(src) + + return src + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer class. + num_layers: the number of sub-encoder-layers in the encoder. + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + key_padding_mask=key_padding_mask, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. 
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """Construct a PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        x_size = x.size(1)
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x_size * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x_size, self.d_model)
+        pe_negative = torch.zeros(x_size, self.d_model)
+        position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ) -> None:
+        super(RelPositionMultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+        # linear transformation for positional encoding.
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        nn.init.xavier_uniform_(self.in_proj.weight)
+        nn.init.constant_(self.in_proj.bias, 0.0)
+        nn.init.constant_(self.out_proj.bias, 0.0)
+
+        nn.init.xavier_uniform_(self.pos_bias_u)
+        nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x: Tensor) -> Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+            x: Input tensor (batch, head, seq_len, 2*seq_len-1).
+ + Returns: + Tensor: tensor of shape (batch, head, seq_len, seq_len) + """ + (batch_size, num_heads, seq_len, n) = x.shape + + if not is_jit_tracing(): + assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, seq_len, seq_len) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, seq_len, seq_len), + (batch_stride, head_stride, time_stride - n_stride, n_stride), + storage_offset=n_stride * (seq_len - 1), + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: Input tensor of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim) + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + Its shape is (batch_size, seq_len). + + Outputs: + A tensor of shape (seq_len, batch_size, embed_dim). + """ + seq_len, batch_size, _ = x.shape + scaling = float(self.head_dim) ** -0.5 + + q, k, v = self.in_proj(x).chunk(3, dim=-1) + + q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + + q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) + + p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) + p = p.permute(0, 2, 3, 1) + + # (batch_size, num_head, seq_len, head_dim) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + + # compute matrix b and matrix d + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + + # (batch_size, num_head, seq_len, seq_len) + attn_output_weights = (matrix_ac + matrix_bd) * scaling + attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size, self.num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + 
attn_output_weights = nn.functional.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + # (batch_size * num_head, seq_len, head_dim) + attn_output = torch.bmm(attn_output_weights, v) + assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + ) + # (seq_len, batch_size, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def _test_text_encoder(): + vocabs = 500 + d_model = 192 + batch_size = 5 + seq_len = 100 + + m = TextEncoder(vocabs=vocabs, d_model=d_model) + x, m, logs, mask = m( + x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)), + x_lengths=torch.full((batch_size,), seq_len), + ) + print(x.shape, m.shape, logs.shape, mask.shape) + + +if __name__ == "__main__": + _test_text_encoder() diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py new file mode 100644 index 000000000..0678b26fe --- /dev/null +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -0,0 +1,106 @@ +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from utils import intersperse
+
+
+class Tokenizer(object):
+    def __init__(self, tokens: str):
+        """
+        Args:
+          tokens: the file that maps tokens to ids
+        """
+        # Parse token file
+        self.token2id: Dict[str, int] = {}
+        with open(tokens, "r", encoding="utf-8") as f:
+            for line in f.readlines():
+                info = line.rstrip().split()
+                if len(info) == 1:
+                    # case of space
+                    token = " "
+                    id = int(info[0])
+                else:
+                    token, id = info[0], int(info[1])
+                self.token2id[token] = id
+
+        self.blank_id = self.token2id["<blk>"]
+        self.oov_id = self.token2id["<unk>"]
+        self.vocab_size = len(self.token2id)
+
+        self.g2p = g2p_en.G2p()
+
+    def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
+        """
+        Args:
+          texts:
+            A list of transcripts.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for text in texts:
+            # Text normalization
+            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+            # Convert to phonemes
+            tokens = self.g2p(text)
+            token_ids = []
+            for t in tokens:
+                if t in self.token2id:
+                    token_ids.append(self.token2id[t])
+                else:
+                    token_ids.append(self.oov_id)
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
+
+    def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+        """
+        Args:
+          tokens_list:
+            A list of token list, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for tokens in tokens_list:
+            token_ids = []
+            for t in tokens:
+                if t in self.token2id:
+                    token_ids.append(self.token2id[t])
+                else:
+                    token_ids.append(self.oov_id)
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
new file mode 100755
index 000000000..eb43a4cc9
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -0,0 +1,893 @@
+#!/usr/bin/env python3
+# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import numpy as np +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +from tokenizer import Tokenizer +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. 
+ + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
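+    # Restore the model weights; the bookkeeping values (best losses/epochs,
+    # batch_idx_train) are copied back into `params` below.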
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
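+
+    Note:
+      Each batch runs two forward/backward passes: the discriminator is updated
+      first (with ``forward_generator=False``), then the generator (with
+      ``forward_generator=True``), each with its own optimizer and the shared
+      GradScaler.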
+ """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+            cur_grad_scale = scaler._scale.item()
+
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                save_bad_model()
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if params.batch_idx_train % params.log_interval == 0:
+            cur_lr_g = max(scheduler_g.get_last_lr())
+            cur_lr_d = max(scheduler_d.get_last_lr())
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+                f"loss[{loss_info}], tot_loss[{tot_loss}], "
+                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+                )
+                tb_writer.add_scalar(
+                    "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+                )
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+                if "returned_sample" in stats_g:
+                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+                    tb_writer.add_audio(
+                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+                    )
+                    tb_writer.add_audio(
+                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+                    )
+                    tb_writer.add_image(
+                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+                    )
+                    tb_writer.add_image(
+                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+                    )
+
+        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info, (speech_hat, speech) = compute_validation_loss(
+                params=params,
+                model=model,
+                tokenizer=tokenizer,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech", speech, params.batch_idx_train, params.sampling_rate
+                )
+
+    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+    """Run the validation process."""
+    model.eval()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations
+    tot_loss = 
MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, :tokens_lens[0].item()] + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) + audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # 
the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py new file mode 100644 index 000000000..c20d13130 --- /dev/null +++ b/egs/ljspeech/TTS/vits/transform.py @@ -0,0 +1,218 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py + +"""Flow-related transformation. + +This code is derived from https://github.com/bayesiains/nflows. 
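+It provides the piecewise rational-quadratic spline transform that the VITS
+stochastic duration predictor uses inside its normalizing flows.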
+ +""" + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +# TODO(kan-bayashi): Documentation and type hint +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = _searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = _searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet + + +def _searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..0fcbb92c1 --- /dev/null +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -0,0 +1,325 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei 
Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    SpeechSynthesisDataset,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class LJSpeechTtsDataModule:
+    """
+    DataModule for TTS experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in TTS
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in TTS tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="TTS data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/spectrogram"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
+        )
+
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
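+        # Each dataloader worker then re-seeds itself with `seed + worker_id`
+        # (see _SeedWorkers above), so workers do not share a random state.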
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py new file mode 100644 index 000000000..2a3dae900 --- /dev/null +++ b/egs/ljspeech/TTS/vits/utils.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple, Union
+import collections
+import logging
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from lhotse.dataset.sampling.base import CutSampler
+from pathlib import Path
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_random_segments(
+    x: torch.Tensor,
+    x_lengths: torch.Tensor,
+    segment_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Get random segments.
+
+    Args:
+        x (Tensor): Input tensor (B, C, T).
+        x_lengths (Tensor): Length tensor (B,).
+        segment_size (int): Segment size.
+
+    Returns:
+        Tensor: Segmented tensor (B, C, segment_size).
+        Tensor: Start index tensor (B,).
+
+    """
+    b, c, t = x.size()
+    max_start_idx = x_lengths - segment_size
+    max_start_idx[max_start_idx < 0] = 0
+    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
+        dtype=torch.long,
+    )
+    segments = get_segments(x, start_idxs, segment_size)
+
+    return segments, start_idxs
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_segments(
+    x: torch.Tensor,
+    start_idxs: torch.Tensor,
+    segment_size: int,
+) -> torch.Tensor:
+    """Get segments.
+
+    Args:
+        x (Tensor): Input tensor (B, C, T).
+        start_idxs (Tensor): Start index tensor (B,).
+        segment_size (int): Segment size.
+
+    Returns:
+        Tensor: Segmented tensor (B, C, segment_size).
+
+    """
+    b, c, t = x.size()
+    segments = x.new_zeros(b, c, segment_size)
+    for i, start_idx in enumerate(start_idxs):
+        segments[i] = x[i, :, start_idx : start_idx + segment_size]
+    return segments
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/commons.py
+def intersperse(sequence, item=0):
+    result = [item] * (len(sequence) * 2 + 1)
+    result[1::2] = sequence
+    return result
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/utils.py
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                   interpolation='none')
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+class MetricsTracker(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        # This class will play a role as metrics tracker. 
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
+
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ""
+        for k, v in self.norm_items():
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        samples = "%.2f" % self["samples"]
+        ans += "over " + str(samples) + " samples."
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('loss_1', 0.1), ('loss_2', 0.07)]
+        """
+        samples = self["samples"] if "samples" in self else 1
+        ans = []
+        for k, v in self.items():
+            if k == "samples":
+                continue
+            norm_value = float(v) / samples
+            ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(
+        self,
+        tb_writer: SummaryWriter,
+        prefix: str,
+        batch_idx: int,
+    ) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_",
+                or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+# checkpoint saving and loading
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def save_checkpoint(
+    filename: Path,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer_g: Optional[Optimizer] = None,
+    optimizer_d: Optional[Optimizer] = None,
+    scheduler_g: Optional[LRSchedulerType] = None,
+    scheduler_d: Optional[LRSchedulerType] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+) -> None:
+    """Save training information to a file.
+
+    Args:
+      filename:
+        The checkpoint filename.
+      model:
+        The model to be saved. We only save its `state_dict()`.
+      params:
+        User defined parameters, e.g., epoch, loss.
+      optimizer_g:
+        The optimizer for generator used in the training.
+        Its `state_dict` will be saved.
+      optimizer_d:
+        The optimizer for discriminator used in the training.
+        Its `state_dict` will be saved.
+      scheduler_g:
+        The learning rate scheduler for generator used in the training.
+        Its `state_dict` will be saved.
+      scheduler_d:
+        The learning rate scheduler for discriminator used in the training.
+        Its `state_dict` will be saved.
+      scaler:
+        The GradScaler to be saved. We only save its `state_dict()`.
+      rank:
+        Used in DDP. We save checkpoint only for the node whose rank is 0.
+    Returns:
+      Return None. 
+ """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py new file mode 100644 index 000000000..d5e20a578 --- /dev/null +++ b/egs/ljspeech/TTS/vits/vits.py @@ -0,0 +1,610 @@ +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""VITS module for GAN-TTS task.""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from hifigan import ( + HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator, +) +from loss import ( + DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + KLDivergenceLoss, + MelSpectrogramLoss, +) +from utils import get_segments +from generator import VITSGenerator + + +AVAILABLE_GENERATERS = { + "vits_generator": VITSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, + "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA +} + + +class VITS(nn.Module): + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` + """ + + def __init__( + self, + # generator related + vocab_size: int, + feature_dim: int = 513, + sampling_rate: int = 22050, + generator_type: str = "vits_generator", + generator_params: Dict[str, Any] = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": None, + "global_channels": -1, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + 
"use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + }, + # discriminator related + discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any] = { + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + # loss related + generator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params: Dict[str, Any] = { + "frame_shift": 256, + "frame_length": 1024, + "n_mels": 80, + }, + lambda_adv: float = 1.0, + lambda_mel: float = 45.0, + lambda_feat_match: float = 2.0, + lambda_dur: float = 1.0, + lambda_kl: float = 1.0, + cache_generator_outputs: bool = True, + ): + """Initialize VITS module. + + Args: + idim (int): Input vocabrary size. + odim (int): Acoustic feature dimension. The actual output channels will + be 1 since VITS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. + generator_type (str): Generator type. + generator_params (Dict[str, Any]): Parameter dict for generator. + discriminator_type (str): Discriminator type. + discriminator_params (Dict[str, Any]): Parameter dict for discriminator. + generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator + adversarial loss. + discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for + discriminator adversarial loss. + feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. + mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. + lambda_adv (float): Loss scaling coefficient for adversarial loss. + lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. + lambda_feat_match (float): Loss scaling coefficient for feat match loss. + lambda_dur (float): Loss scaling coefficient for duration loss. + lambda_kl (float): Loss scaling coefficient for KL divergence loss. + cache_generator_outputs (bool): Whether to cache generator outputs. 
+ + """ + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + if generator_type == "vits_generator": + # NOTE(kan-bayashi): Update parameters for the compatibility. + # The idim and odim is automatically decided from input data, + # where idim represents #vocabularies and odim represents + # the input acoustic feature dimension. + generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) + self.generator = generator_class( + **generator_params, + ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, + ) + self.generator_adv_loss = GeneratorAdversarialLoss( + **generator_adv_loss_params, + ) + self.discriminator_adv_loss = DiscriminatorAdversarialLoss( + **discriminator_adv_loss_params, + ) + self.feat_match_loss = FeatureMatchLoss( + **feat_match_loss_params, + ) + mel_loss_params.update(sampling_rate=sampling_rate) + self.mel_loss = MelSpectrogramLoss( + **mel_loss_params, + ) + self.kl_loss = KLDivergenceLoss() + + # coefficients + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_kl = lambda_kl + self.lambda_feat_match = lambda_feat_match + self.lambda_dur = lambda_dur + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.sampling_rate = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + forward_generator: bool = True, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + forward_generator (bool): Whether to forward generator. + + Returns: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. 
+ """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + return_sample=return_sample, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + + def _forward_generator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs + _, z_p, m_p, logs_p, _, logs_q = outs_ + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + if not return_sample: + mel_loss = self.mel_loss(speech_hat_, speech_) + else: + mel_loss, (mel_hat_, mel_) = self.mel_loss( + speech_hat_, speech_, return_mel=True + ) + kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) + dur_loss = torch.sum(dur_nll.float()) + adv_loss = self.generator_adv_loss(p_hat) + feat_match_loss = self.feat_match_loss(p_hat, p) + + mel_loss = mel_loss * self.lambda_mel + kl_loss = kl_loss * self.lambda_kl + dur_loss = dur_loss * self.lambda_dur + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss + + stats = dict( + generator_loss=loss.item(), + generator_mel_loss=mel_loss.item(), + generator_kl_loss=kl_loss.item(), + generator_dur_loss=dur_loss.item(), + generator_adv_loss=adv_loss.item(), + generator_feat_match_loss=feat_match_loss.item(), + 
) + + if return_sample: + stats["returned_sample"] = ( + speech_hat_[0].data.cpu().numpy(), + speech_[0].data.cpu().numpy(), + mel_hat_[0].data.cpu().numpy(), + mel_[0].data.cpu().numpy(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def _forward_discrminator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform discriminator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_.detach()) + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) + loss = real_loss + fake_loss + + stats = dict( + discriminator_loss=loss.item(), + discriminator_real_loss=real_loss.item(), + discriminator_fake_loss=fake_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def inference( + self, + text: torch.Tensor, + feats: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for single sample. + + Args: + text (Tensor): Input text index tensor (T_text,). + feats (Tensor): Feature tensor (T_feats, aux_channels). + sids (Tensor): Speaker index tensor (1,). + spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). + lids (Tensor): Language index tensor (1,). + durations (Tensor): Ground-truth duration tensor (T_text,). + noise_scale (float): Noise scale value for flow. 
+ noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). + """ + # setup + text = text[None] + text_lengths = torch.tensor( + [text.size(1)], + dtype=torch.long, + device=text.device, + ) + if sids is not None: + sids = sids.view(1) + if lids is not None: + lids = lids.view(1) + if durations is not None: + durations = durations.view(1, 1, -1) + + # inference + if use_teacher_forcing: + assert feats is not None + feats = feats[None].transpose(1, 2) + feats_lengths = torch.tensor( + [feats.size(2)], + dtype=torch.long, + device=feats.device, + ) + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + max_len=max_len, + use_teacher_forcing=use_teacher_forcing, + ) + else: + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + spembs=spembs, + lids=lids, + dur=durations, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav.view(-1), att_w[0], dur[0] + + def inference_batch( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for one batch. + + Args: + text (Tensor): Input text index tensor (B, T_text). + text_lengths (Tensor): Input text index tensor (B,). + sids (Tensor): Speaker index tensor (B,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + + Returns: + * wav (Tensor): Generated waveform tensor (B, T_wav). + * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). + """ + # inference + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav, att_w, dur diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py new file mode 100644 index 000000000..fbe1be52b --- /dev/null +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -0,0 +1,349 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""WaveNet modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
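+In VITS these residual blocks form the backbone of the posterior encoder and
+of the residual coupling (flow) layers.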
+ +""" + +import math +import logging + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +class WaveNet(torch.nn.Module): + """WaveNet with global conditioning.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_size: int = 3, + layers: int = 30, + stacks: int = 3, + base_dilation: int = 2, + residual_channels: int = 64, + aux_channels: int = -1, + gate_channels: int = 128, + skip_channels: int = 64, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + use_first_conv: bool = False, + use_last_conv: bool = False, + scale_residual: bool = False, + scale_skip_connect: bool = False, + ): + """Initialize WaveNet module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + base_dilation (int): Base dilation factor. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + aux_channels (int): Number of channels for local conditioning feature. + global_channels (int): Number of channels for global conditioning feature. + dropout_rate (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv layer. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_first_conv (bool): Whether to use the first conv layers. + use_last_conv (bool): Whether to use the last conv layers. + scale_residual (bool): Whether to scale the residual outputs. + scale_skip_connect (bool): Whether to scale the skip connection outputs. + + """ + super().__init__() + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + self.base_dilation = base_dilation + self.use_first_conv = use_first_conv + self.use_last_conv = use_last_conv + self.scale_skip_connect = scale_skip_connect + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + if self.use_first_conv: + self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = base_dilation ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + global_channels=global_channels, + dilation=dilation, + dropout_rate=dropout_rate, + bias=bias, + scale_residual=scale_residual, + ) + self.conv_layers += [conv] + + # define output layers + if self.use_last_conv: + self.last_conv = torch.nn.Sequential( + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, skip_channels, bias=True), + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, out_channels, bias=True), + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Calculate forward propagation. 
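+
+        The input is passed through the stack of dilated residual blocks and
+        the per-layer skip connections are summed (and optionally scaled by
+        ``sqrt(1 / num_layers)``) to form the output.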
+ + Args: + x (Tensor): Input noise signal (B, 1, T) if use_first_conv else + (B, residual_channels, T). + x_mask (Optional[Tensor]): Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T) if use_last_conv else + (B, residual_channels, T). + + """ + # encode to hidden representation + if self.use_first_conv: + x = self.first_conv(x) + + # residual block + skips = 0.0 + for f in self.conv_layers: + x, h = f(x, x_mask=x_mask, c=c, g=g) + skips = skips + h + x = skips + if self.scale_skip_connect: + x = x * math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + if self.use_last_conv: + x = self.last_conv(x) + + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size( + layers: int, + stacks: int, + kernel_size: int, + base_dilation: int, + ) -> int: + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self) -> int: + """Return receptive field size.""" + return self._get_receptive_field_size( + self.layers, self.stacks, self.kernel_size, self.base_dilation + ) + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super().__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels: int, out_channels: int, bias: bool): + """Initialize 1x1 Conv1d module.""" + super().__init__( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size: int = 3, + residual_channels: int = 64, + gate_channels: int = 128, + skip_channels: int = 64, + aux_channels: int = 80, + global_channels: int = -1, + dropout_rate: float = 0.0, + dilation: int = 1, + bias: bool = True, + scale_residual: bool = False, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. + skip_channels (int): Number of channels for skip connection. + aux_channels (int): Number of local conditioning channels. + dropout (float): Dropout probability. + dilation (int): Dilation factor. 
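+            gate_channels (int): Number of channels for the gated activation.
+            global_channels (int): Number of global conditioning channels.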
+ bias (bool): Whether to add bias parameter in convolution layers. + scale_residual (bool): Whether to scale the residual outputs. + + """ + super().__init__() + self.dropout_rate = dropout_rate + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.scale_residual = scale_residual + + # check + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + assert gate_channels % 2 == 0 + + # dilation conv + padding = (kernel_size - 1) // 2 * dilation + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # global conditioning + if global_channels > 0: + self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) + else: + self.conv1x1_glo = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + + # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency + # (integrate res 1x1 + skip 1x1 convs) + self.conv1x1_out = Conv1d1x1( + gate_out_channels, residual_channels + skip_channels, bias=bias + ) + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). 
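+
+        Note:
+            Local and global conditioning are added to the two halves of the
+            dilated convolution output before the gated activation
+            ``tanh(x_a) * sigmoid(x_b)``.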
+ + """ + residual = x + x = F.dropout(x, p=self.dropout_rate, training=self.training) + x = self.conv(x) + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + # global conditioning + if g is not None: + g = self.conv1x1_glo(g) + ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ga, xb + gb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # residual + skip 1x1 conv + x = self.conv1x1_out(x) + if x_mask is not None: + x = x * x_mask + + # split integrated conv results + x, s = x.split([self.residual_channels, self.skip_channels], dim=1) + + # for residual connection + x = x + residual + if self.scale_residual: + x = x * math.sqrt(0.5) + + return x, s diff --git a/pyproject.toml b/pyproject.toml index c40143fb9..435256416 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,5 @@ exclude = ''' | icefall\/diagnostics\.py | icefall\/profiler\.py | egs\/librispeech\/ASR\/zipformer + | egs\/ljspeech\/TTS\/vits ''' From f08af2fa2217e226394b6f03442952d104bd984e Mon Sep 17 00:00:00 2001 From: LoganLiu66 <2319277867@qq.com> Date: Mon, 4 Dec 2023 22:29:42 +0800 Subject: [PATCH 107/113] fix initial states (#1398) Co-authored-by: liujiawang02 --- .../pruned_transducer_stateless7_streaming/decode_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py index 0d7e86fcf..2c4b144fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -82,12 +82,12 @@ class DecodeStream(object): self.pad_length = 7 if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] elif params.decoding_method == "modified_beam_search": self.hyps = HypothesisList() self.hyps.add( Hypothesis( - ys=[params.blank_id] * params.context_size, + ys=[-1] * (params.context_size - 1) + [params.blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) From 735fb9a73dea7d27e95056add6598ae7a282d6f9 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 6 Dec 2023 09:59:19 +0800 Subject: [PATCH 108/113] A TTS recipe VITS on VCTK dataset (#1380) * init * isort formatted * minor updates * Create shared * Update prepare_tokens_vctk.py * Update prepare_tokens_vctk.py * Update prepare_tokens_vctk.py * Update prepare.sh * updated * Update train.py * Update train.py * Update tts_datamodule.py * Update train.py * Update train.py * Update train.py * Update train.py * Update train.py * Update train.py * fixed formatting issue * Update infer.py * removed redundant files * Create monotonic_align * removed redundant files * created symlinks * Update prepare.sh * minor adjustments * Create requirements_tts.txt * Update requirements_tts.txt added version constraints * Update infer.py * Update infer.py * Update infer.py * updated docs * Update export-onnx.py * Update export-onnx.py * Update test_onnx.py * updated requirements.txt * Update test_onnx.py * Update test_onnx.py * docs updated * docs fixed * minor updates --- docs/source/recipes/TTS/index.rst | 1 + docs/source/recipes/TTS/ljspeech/vits.rst | 12 +- 
docs/source/recipes/TTS/vctk/vits.rst | 125 +++ egs/ljspeech/TTS/prepare.sh | 16 +- egs/ljspeech/TTS/vits/duration_predictor.py | 1 - egs/ljspeech/TTS/vits/export-onnx.py | 8 +- egs/ljspeech/TTS/vits/flow.py | 1 - egs/ljspeech/TTS/vits/generator.py | 5 +- egs/ljspeech/TTS/vits/infer.py | 29 +- egs/ljspeech/TTS/vits/loss.py | 1 - egs/ljspeech/TTS/vits/posterior_encoder.py | 2 +- egs/ljspeech/TTS/vits/residual_coupling.py | 1 - egs/ljspeech/TTS/vits/test_onnx.py | 2 +- egs/ljspeech/TTS/vits/text_encoder.py | 48 +- egs/ljspeech/TTS/vits/tokenizer.py | 4 +- egs/ljspeech/TTS/vits/train.py | 93 +- egs/ljspeech/TTS/vits/tts_datamodule.py | 2 +- egs/ljspeech/TTS/vits/utils.py | 14 +- egs/ljspeech/TTS/vits/vits.py | 9 +- egs/ljspeech/TTS/vits/wavenet.py | 3 +- .../TTS/local/compute_spectrogram_vctk.py | 107 ++ .../TTS/local/display_manifest_statistics.py | 83 ++ egs/vctk/TTS/local/prepare_token_file.py | 104 ++ egs/vctk/TTS/local/prepare_tokens_vctk.py | 61 + egs/vctk/TTS/local/validate_manifest.py | 70 ++ egs/vctk/TTS/prepare.sh | 131 +++ egs/vctk/TTS/shared | 1 + egs/vctk/TTS/vits/duration_predictor.py | 1 + egs/vctk/TTS/vits/export-onnx.py | 284 +++++ egs/vctk/TTS/vits/flow.py | 1 + egs/vctk/TTS/vits/generator.py | 1 + egs/vctk/TTS/vits/hifigan.py | 1 + egs/vctk/TTS/vits/infer.py | 272 +++++ egs/vctk/TTS/vits/loss.py | 1 + egs/vctk/TTS/vits/monotonic_align | 1 + egs/vctk/TTS/vits/posterior_encoder.py | 1 + egs/vctk/TTS/vits/residual_coupling.py | 1 + egs/vctk/TTS/vits/test_onnx.py | 138 +++ egs/vctk/TTS/vits/text_encoder.py | 1 + egs/vctk/TTS/vits/tokenizer.py | 1 + egs/vctk/TTS/vits/train.py | 1000 +++++++++++++++++ egs/vctk/TTS/vits/transform.py | 1 + egs/vctk/TTS/vits/tts_datamodule.py | 338 ++++++ egs/vctk/TTS/vits/utils.py | 1 + egs/vctk/TTS/vits/vits.py | 1 + egs/vctk/TTS/vits/wavenet.py | 1 + requirements-tts.txt | 6 + requirements.txt | 2 + 48 files changed, 2904 insertions(+), 84 deletions(-) create mode 100644 docs/source/recipes/TTS/vctk/vits.rst create mode 100755 egs/vctk/TTS/local/compute_spectrogram_vctk.py create mode 100755 egs/vctk/TTS/local/display_manifest_statistics.py create mode 100755 egs/vctk/TTS/local/prepare_token_file.py create mode 100755 egs/vctk/TTS/local/prepare_tokens_vctk.py create mode 100755 egs/vctk/TTS/local/validate_manifest.py create mode 100755 egs/vctk/TTS/prepare.sh create mode 120000 egs/vctk/TTS/shared create mode 120000 egs/vctk/TTS/vits/duration_predictor.py create mode 100755 egs/vctk/TTS/vits/export-onnx.py create mode 120000 egs/vctk/TTS/vits/flow.py create mode 120000 egs/vctk/TTS/vits/generator.py create mode 120000 egs/vctk/TTS/vits/hifigan.py create mode 100755 egs/vctk/TTS/vits/infer.py create mode 120000 egs/vctk/TTS/vits/loss.py create mode 120000 egs/vctk/TTS/vits/monotonic_align create mode 120000 egs/vctk/TTS/vits/posterior_encoder.py create mode 120000 egs/vctk/TTS/vits/residual_coupling.py create mode 100755 egs/vctk/TTS/vits/test_onnx.py create mode 120000 egs/vctk/TTS/vits/text_encoder.py create mode 120000 egs/vctk/TTS/vits/tokenizer.py create mode 100755 egs/vctk/TTS/vits/train.py create mode 120000 egs/vctk/TTS/vits/transform.py create mode 100644 egs/vctk/TTS/vits/tts_datamodule.py create mode 120000 egs/vctk/TTS/vits/utils.py create mode 120000 egs/vctk/TTS/vits/vits.py create mode 120000 egs/vctk/TTS/vits/wavenet.py create mode 100644 requirements-tts.txt diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst index aa891c072..80d67a2f3 100644 --- a/docs/source/recipes/TTS/index.rst +++ 
b/docs/source/recipes/TTS/index.rst @@ -5,3 +5,4 @@ TTS :maxdepth: 2 ljspeech/vits + vctk/vits \ No newline at end of file diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 385fd3c70..d08aa0f47 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -4,6 +4,10 @@ VITS This tutorial shows you how to train an VITS model with the `LJSpeech `_ dataset. +.. note:: + + TTS related recipes require packages in ``requirements-tts.txt``. + .. note:: The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ @@ -27,6 +31,12 @@ To run stage 1 to stage 5, use Build Monotonic Alignment Search -------------------------------- +.. code-block:: bash + + $ ./prepare.sh --stage -1 --stop_stage -1 + +or + .. code-block:: bash $ cd vits/monotonic_align @@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir $ ./vits/infer.py \ --epoch 1000 \ --exp-dir vits/exp \ - --tokens data/tokens.txt + --tokens data/tokens.txt \ --max-duration 500 .. note:: diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst new file mode 100644 index 000000000..34024a5ea --- /dev/null +++ b/docs/source/recipes/TTS/vctk/vits.rst @@ -0,0 +1,125 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `VCTK `_ dataset. + +.. note:: + + TTS related recipes require packages in ``requirements-tts.txt``. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/vctk/TTS + $ ./prepare.sh + +To run stage 1 to stage 6, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 6 + + +Build Monotonic Alignment Search +-------------------------------- + +To build the monotonic alignment search, use the following commands: + +.. code-block:: bash + + $ ./prepare.sh --stage -1 --stop_stage -1 + +or + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 350 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt \ + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. 
code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 8ee40896e..ed0a07f5e 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -nj=1 -stage=-1 +stage=0 stop_stage=100 dl_dir=$PWD/download @@ -25,6 +24,17 @@ log() { log "dl_dir: $dl_dir" +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" @@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --tokens data/tokens.txt fi fi - - diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py index c29a28479..1a8190014 100644 --- a/egs/ljspeech/TTS/vits/duration_predictor.py +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -14,7 +14,6 @@ from typing import Optional import torch import torch.nn.functional as F - from flow import ( ConvFlow, DilatedDepthSeparableConv, diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 154de4bf4..2068adeea 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -180,7 +180,13 @@ def export_model_onnx( model_filename, verbose=False, opset_version=opset_version, - input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "alpha", + ], output_names=["audio"], dynamic_axes={ "tokens": {0: "N", 1: "T"}, diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py index 206bd5e3e..2b84f6434 100644 --- a/egs/ljspeech/TTS/vits/flow.py +++ b/egs/ljspeech/TTS/vits/flow.py @@ -13,7 +13,6 @@ import math from typing import Optional, Tuple, Union import torch - from transform import piecewise_rational_quadratic_transform diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index efb0e254c..66c8cedb1 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -16,9 +16,6 @@ from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F - -from icefall.utils import make_pad_mask - from duration_predictor import StochasticDurationPredictor from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder @@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder from utils import get_random_segments +from icefall.utils import make_pad_mask + class VITSGenerator(torch.nn.Module): """Generator module in VITS, `Conditional Variational Autoencoder diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 91a35e360..cf0d20ae2 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ 
b/egs/ljspeech/TTS/vits/infer.py @@ -36,13 +36,12 @@ import k2 import torch import torch.nn as nn import torchaudio - -from train import get_model, get_params from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule from icefall.checkpoint import load_checkpoint from icefall.utils import AttributeDict, setup_logger -from tts_datamodule import LJSpeechTtsDataModule def get_parser(): @@ -107,12 +106,12 @@ def infer_dataset( for i in range(batch_size): torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), - audio[i:i + 1, :audio_lens[i]], + audio[i : i + 1, : audio_lens[i]], sample_rate=params.sampling_rate, ) torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), - audio_pred[i:i + 1, :audio_lens_pred[i]], + audio_pred[i : i + 1, : audio_lens_pred[i]], sample_rate=params.sampling_rate, ) @@ -144,14 +143,24 @@ def infer_dataset( audio_lens = batch["audio_lens"].tolist() cut_ids = [cut.id for cut in batch["cut"]] - audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred, _, durations = model.inference_batch( + text=tokens, text_lengths=tokens_lens + ) audio_pred = audio_pred.detach().cpu() # convert to samples - audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) futures.append( executor.submit( - _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + _save_worker, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, ) ) @@ -160,7 +169,9 @@ def infer_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) # return results for f in futures: f.result() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py index 21aaad6e7..2f4dc9bc0 100644 --- a/egs/ljspeech/TTS/vits/loss.py +++ b/egs/ljspeech/TTS/vits/loss.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union import torch import torch.distributions as D import torch.nn.functional as F - from lhotse.features.kaldi import Wav2LogFilterBank diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py index 6b8a5be52..1104fb864 100644 --- a/egs/ljspeech/TTS/vits/posterior_encoder.py +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple import torch +from wavenet import Conv1d, WaveNet from icefall.utils import make_pad_mask -from wavenet import WaveNet, Conv1d class PosteriorEncoder(torch.nn.Module): diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py index 2d6807cb7..f9a2a3786 100644 --- a/egs/ljspeech/TTS/vits/residual_coupling.py +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits. 
from typing import Optional, Tuple, Union import torch - from flow import FlipFlow from wavenet import WaveNet diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 8acca7c02..686fee2a0 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -28,10 +28,10 @@ Use the onnx model to generate a wav: import argparse import logging + import onnxruntime as ort import torch import torchaudio - from tokenizer import Tokenizer diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index 9f337e45b..fcbae7103 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -169,9 +169,7 @@ class Transformer(nn.Module): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - x = self.encoder( - x, pos_emb, key_padding_mask=key_padding_mask - ) # (T, N, C) + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) x = self.after_norm(x) @@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = RelPositionMultiheadAttention( + d_model, num_heads, dropout=dropout + ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module): key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) """ # macaron style feed-forward module - src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + src = src + self.ff_scale * self.dropout( + self.feed_forward_macaron(self.norm_ff_macaron(src)) + ) # multi-head self-attention module src_attn = self.self_attn( @@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module): q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) - v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + v = ( + v.contiguous() + .view(seq_len, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) - p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + p = self.linear_pos(pos_emb).view( + pos_emb.size(0), -1, self.num_heads, self.head_dim + ) # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) p = p.permute(0, 2, 3, 1) @@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch_size, num_head, seq_len, seq_len) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) - matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift( + matrix_bd + ) # (batch_size, num_head, seq_len, seq_len) # (batch_size, num_head, seq_len, seq_len) attn_output_weights = (matrix_ac + matrix_bd) * scaling - 
attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, seq_len) @@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module): # (batch_size * num_head, seq_len, head_dim) attn_output = torch.bmm(attn_output_weights, v) - assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + assert attn_output.shape == ( + batch_size * self.num_heads, + seq_len, + self.head_dim, + ) attn_output = ( - attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, batch_size, self.embed_dim) ) # (seq_len, batch_size, embed_dim) attn_output = self.out_proj(attn_output) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 0678b26fe..70f1240b4 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -78,7 +78,9 @@ class Tokenizer(object): return token_ids_list - def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + def tokens_to_token_ids( + self, tokens_list: List[str], intersperse_blank: bool = True + ): """ Args: tokens_list: diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index eb43a4cc9..71c4224fa 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -18,21 +18,25 @@ import argparse import logging -import numpy as np from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 +import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed -from torch.optim import Optimizer +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -41,11 +45,6 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, setup_logger, str2bool -from tokenizer import Tokenizer -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -385,11 +384,12 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size try: with autocast(enabled=params.use_fp16): @@ -446,7 +446,9 @@ def train_one_epoch( # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: @@ -482,9 +484,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train @@ -492,19 +492,34 @@ def train_one_epoch( if "returned_sample" in stats_g: speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] tb_writer.add_audio( - "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_audio( - "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_image( - "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", ) tb_writer.add_image( - "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", ) - if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info, (speech_hat, speech) = compute_validation_loss( params=params, @@ -523,10 +538,16 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] @@ -555,11 +576,17 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size # forward discriminator loss_d, stats_d = model( @@ -596,12 +623,17 @@ def compute_validation_loss( if batch_idx == 0 and rank == 0: inner_model = model.module if isinstance(model, DDP) else model audio_pred, _, duration = inner_model.inference( - text=tokens[0, :tokens_lens[0].item()] + text=tokens[0, : tokens_lens[0].item()] ) audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) - audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + audio_len_pred = ( + 
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() returned_sample = (audio_pred, audio_gt) if world_size > 1: @@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom( batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) try: # for discriminator with autocast(enabled=params.use_fp16): diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 0fcbb92c1..81bb9ed13 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, - SpeechSynthesisDataset, PrecomputedFeatures, SimpleCutSampler, SpecAugment, + SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py index 2a3dae900..6a067f596 100644 --- a/egs/ljspeech/TTS/vits/utils.py +++ b/egs/ljspeech/TTS/vits/utils.py @@ -14,15 +14,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler -from pathlib import Path from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -97,23 +97,23 @@ def plot_feature(spectrogram): global MATPLOTLIB_FLAG if not MATPLOTLIB_FLAG: import matplotlib + matplotlib.use("Agg") MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') + mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt import numpy as np fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") plt.colorbar(im, ax=ax) plt.xlabel("Frames") plt.ylabel("Channels") plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index d5e20a578..b4f0c21e6 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -from torch.cuda.amp import autocast - +from generator import VITSGenerator from hifigan import ( HiFiGANMultiPeriodDiscriminator, HiFiGANMultiScaleDiscriminator, @@ -25,9 +24,8 @@ from loss import 
( KLDivergenceLoss, MelSpectrogramLoss, ) +from torch.cuda.amp import autocast from utils import get_segments -from generator import VITSGenerator - AVAILABLE_GENERATERS = { "vits_generator": VITSGenerator, @@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = { class VITS(nn.Module): - """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` - """ + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" def __init__( self, diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py index fbe1be52b..5db461d5c 100644 --- a/egs/ljspeech/TTS/vits/wavenet.py +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. """ -import math import logging - +import math from typing import Optional, Tuple import torch diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py new file mode 100755 index 000000000..440ac1245 --- /dev/null +++ b/egs/vctk/TTS/local/compute_spectrogram_vctk.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the VCTK dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_vctk(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(32, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet + ).resample(sampling_rate=sampling_rate) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_vctk() diff --git a/egs/vctk/TTS/local/display_manifest_statistics.py b/egs/vctk/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..0472e2cea --- /dev/null +++ b/egs/vctk/TTS/local/display_manifest_statistics.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. 
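+
+An example of the output for the VCTK spectrogram manifest is included as a
+comment at the bottom of this file.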
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/vctk_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 41:02:18 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 1.2 │ +├───────────────────────────┼──────────┤ +│ min │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.6 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 3.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.1 │ +├───────────────────────────┼──────────┤ +│ max │ 16.6 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 43873 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 41:02:18 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:01 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py new file mode 100755 index 000000000..c6636c3ad --- /dev/null +++ b/egs/vctk/TTS/local/prepare_token_file.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. 
+ + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. + "", # 2 for OOV + ] + all_tokens = set() + + cut_set = load_manifest(manifest_file) + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + for t in cut.tokens: + all_tokens.add(t) + + all_tokens = extra_tokens + list(all_tokens) + + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py new file mode 100755 index 000000000..32e1c7dfa --- /dev/null +++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
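+
+Each cut gets a ``tokens`` field containing the phoneme sequence produced by
+g2p_en after Tacotron-style text normalization.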
+""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest +from tqdm.auto import tqdm + + +def prepare_tokens_vctk(): + output_dir = Path("data/spectrogram") + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in tqdm(cut_set): + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_vctk() diff --git a/egs/vctk/TTS/local/validate_manifest.py b/egs/vctk/TTS/local/validate_manifest.py new file mode 100755 index 000000000..cd466303e --- /dev/null +++ b/egs/vctk/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh new file mode 100755 index 000000000..87150ad31 --- /dev/null +++ b/egs/vctk/TTS/prepare.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 + +dl_dir=$PWD/download + +. 
shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/VCTK, + # you can create a symlink + # + # ln -sfv /path/to/VCTK $dl_dir/VCTK + # + if [ ! -d $dl_dir/VCTK ]; then + lhotse download vctk $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare VCTK manifest" + # We assume that you have downloaded the VCTK corpus + # to $dl_dir/VCTK + mkdir -p data/manifests + if [ ! -e data/manifests/.vctk.done ]; then + lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests + touch data/manifests/.vctk.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for VCTK" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.vctk.done ]; then + ./local/compute_spectrogram_vctk.py + touch data/spectrogram/.vctk.done + fi + + if [ ! -e data/spectrogram/.vctk-validated.done ]; then + log "Validating data/fbank for VCTK" + ./local/validate_manifest.py \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for VCTK" + if [ ! -e data/spectrogram/.vctk_with_token.done ]; then + ./local/prepare_tokens_vctk.py + mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the VCTK cuts into train, valid and test sets" + if [ ! -e data/spectrogram/.vctk_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/vctk_cuts_all.jsonl.gz \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz \ + data/spectrogram/vctk_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz \ + data/spectrogram/vctk_cuts_test.jsonl.gz + + rm data/spectrogram/vctk_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/vctk_cuts_all.jsonl.gz \ + data/spectrogram/vctk_cuts_train.jsonl.gz + touch data/spectrogram/.vctk_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! 
-e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \ + --tokens data/tokens.txt + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate speakers file" + if [ ! -e data/speakers.txt ]; then + gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \ + | jq '.speaker' | sed 's/"//g' \ + | sort | uniq > data/speakers.txt + fi +fi diff --git a/egs/vctk/TTS/shared b/egs/vctk/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/vctk/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py new file mode 120000 index 000000000..9972b476f --- /dev/null +++ b/egs/vctk/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py new file mode 100755 index 000000000..7c9664cc1 --- /dev/null +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
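+
+        A minimal sketch of how this helper is used in main() below (the
+        file name here is illustrative only):
+
+            add_meta_data(filename="vits-epoch-1000.onnx", meta_data={"model_type": "VITS"})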
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + speaker: int = 20, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + speaker (int): + Speaker ID. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + sids=speaker, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + speaker = torch.tensor([1], dtype=torch.int64) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "speaker", + "alpha", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + "speaker": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py new file mode 120000 index 000000000..e65d91ea7 --- /dev/null +++ b/egs/vctk/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py new file mode 120000 index 000000000..611679bfa --- /dev/null +++ b/egs/vctk/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/hifigan.py b/egs/vctk/TTS/vits/hifigan.py new file mode 120000 index 000000000..5ac025de7 --- /dev/null +++ b/egs/vctk/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py new 
file mode 100755 index 000000000..06c25f02e --- /dev/null +++ b/egs/vctk/TTS/vits/infer.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +import torch.nn as nn +import torchaudio +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import VctkTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, + speaker_map: Dict[str, int], +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), + audio_pred[i : i + 1, : audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
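+
+    # Saving wav files is offloaded to a single background worker thread below,
+    # so that disk I/O does not block inference on the next batch.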
+ + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]) + .int() + .to(device) + ) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch( + text=tokens, + text_lengths=tokens_lens, + sids=speakers, + ) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + # we need cut ids to display recognition results. 
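+    # (the cut ids are also used to name the saved ground-truth / generated wav files)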
+ args.return_cuts = True + vctk = VctkTtsDataModule(args) + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + test_cuts = vctk.test_cuts() + test_dl = vctk.test_dataloaders(test_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + infer_sets = {"test": test_dl, "valid": valid_dl} + + for subset, dl in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + tokenizer=tokenizer, + speaker_map=speaker_map, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py new file mode 120000 index 000000000..672e5ff68 --- /dev/null +++ b/egs/vctk/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/monotonic_align b/egs/vctk/TTS/vits/monotonic_align new file mode 120000 index 000000000..71934e7cc --- /dev/null +++ b/egs/vctk/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py new file mode 120000 index 000000000..41d64a3a6 --- /dev/null +++ b/egs/vctk/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py new file mode 120000 index 000000000..f979adbf0 --- /dev/null +++ b/egs/vctk/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py new file mode 100755 index 000000000..757e67fc1 --- /dev/null +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +from pathlib import Path + +import onnxruntime as ort +import torch +import torchaudio +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: speaker.numpy(), + self.model.get_inputs()[5].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + args.num_spks = len(speaker_map) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." 
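+    # Tokenize the text, run the ONNX generator and save the result; the sample
+    # rate passed to torchaudio.save below must match the one used in training
+    # (22050 Hz for this recipe).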
+ tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + speaker = torch.tensor([1], dtype=torch.int64) # (1, ) + audio = model(tokens, tokens_lens, speaker) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py new file mode 120000 index 000000000..0efba277e --- /dev/null +++ b/egs/vctk/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py new file mode 120000 index 000000000..057b0dc4b --- /dev/null +++ b/egs/vctk/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py new file mode 100755 index 000000000..56f167a17 --- /dev/null +++ b/egs/vctk/TTS/vits/train.py @@ -0,0 +1,1000 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
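+
+"""
+A rough usage sketch (argument values are illustrative; see get_parser() and
+VctkTtsDataModule.add_arguments() for the full list of options):
+
+  ./vits/train.py \
+    --world-size 4 \
+    --num-epochs 1000 \
+    --exp-dir vits/exp \
+    --tokens data/tokens.txt \
+    --max-duration 350
+"""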
+ + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import VctkTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. 
+ + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
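+
+    # Only the bookkeeping values listed in `keys` below are copied back into
+    # `params`; optimizer/scheduler/grad-scaler states from the same checkpoint
+    # are restored later in run().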
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": params.num_spks, + "langs": None, + "spk_embed_dim": None, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: Dict[str, int], +): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) + ) + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. 
It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
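+            # Heuristic used below: double the scale while it is below 8.0 (or below
+            # 32.0, checked every 400 batches); if it collapses towards zero, save the
+            # failing model and abort training.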
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + speaker_map=speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + sids=speakers[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + speaker_map: Dict[str, int], + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + vctk = VctkTtsDataModule(args) + + train_cuts = vctk.train_cuts() + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), 
lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = vctk.train_dataloaders(train_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + speaker_map=speaker_map, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + speaker_map=speaker_map, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + 
best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py new file mode 120000 index 000000000..962647408 --- /dev/null +++ b/egs/vctk/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..8b2a96b09 --- /dev/null +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -0,0 +1,338 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class VctkTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
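+
+    A minimal usage sketch, mirroring how train.py and infer.py use it:
+
+        parser = argparse.ArgumentParser()
+        VctkTtsDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        vctk = VctkTtsDataModule(args)
+        train_dl = vctk.train_dataloaders(vctk.train_cuts())
+        valid_dl = vctk.valid_dataloaders(vctk.valid_cuts())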
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + 
feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") + + @lru_cache() + def speakers(self) -> Dict[str, int]: + logging.info("About to get speakers") + with open(self.args.speakers) as f: + speakers = {line.strip(): i for i, line in enumerate(f)} + return speakers diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py new file mode 120000 index 000000000..085e764b4 --- /dev/null +++ b/egs/vctk/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py new file mode 120000 index 000000000..1f58cf6fe --- /dev/null +++ b/egs/vctk/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py new file mode 120000 index 000000000..28f0a78ee --- /dev/null +++ b/egs/vctk/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/requirements-tts.txt b/requirements-tts.txt new file mode 100644 index 000000000..c30e23d54 --- /dev/null +++ b/requirements-tts.txt @@ -0,0 +1,6 @@ +# for TTS recipes +matplotlib==3.8.2 +cython==3.0.6 +numba==0.58.1 +g2p_en==2.1.0 +espnet_tts_frontend==0.0.3 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9502fcbd2..a1a46ae64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ tensorboard typeguard dill black==22.3.0 +onnx==1.15.0 +onnxruntime==1.16.3 \ No newline at end of file From b87ed26c09e9f5bb29174dd01f13670fb6124583 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:33:45 +0800 Subject: [PATCH 109/113] Normalize dockerfile (#1400) --- docker/torch1.12.1-cuda11.3.dockerfile | 3 +-- docker/torch1.13.0-cuda11.6.dockerfile | 3 +-- docker/torch1.9.0-cuda10.2.dockerfile | 3 +-- docker/torch2.0.0-cuda11.7.dockerfile | 3 +-- docker/torch2.1.0-cuda11.8.dockerfile | 3 +-- docker/torch2.1.0-cuda12.1.dockerfile | 3 +-- 6 files changed, 6 insertions(+), 12 deletions(-) diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index ed746abe3..deb5715cc 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f 
https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index 9657866e5..afc6c1b84 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index a92af7ad0..9ff225b54 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -58,7 +58,6 @@ RUN pip uninstall -y tqdm && \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index 07296e6f0..db8076560 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index e500e9a6a..b006b0d96 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index c3f12323e..1b078dc22 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ From bda72f86fffe591d334630da522dba4cf5c66341 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Dec 2023 06:32:40 +0800 Subject: [PATCH 
110/113] minor adjustments to the VITS recipes for onnx runtime (#1405) --- egs/ljspeech/TTS/vits/export-onnx.py | 4 ++-- egs/ljspeech/TTS/vits/test_onnx.py | 4 ++-- egs/vctk/TTS/vits/export-onnx.py | 4 ++-- egs/vctk/TTS/vits/test_onnx.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 2068adeea..bca6aec99 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -176,7 +176,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), model_filename, verbose=False, opset_version=opset_version, @@ -184,8 +184,8 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", - "noise_scale_dur", "alpha", + "noise_scale_dur", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 686fee2a0..fcbc1d663 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -92,8 +92,8 @@ class OnnxModel: self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), }, )[0] return torch.from_numpy(out) diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 7c9664cc1..cfc74fd0a 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -187,7 +187,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), model_filename, verbose=False, opset_version=opset_version, @@ -195,9 +195,9 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", + "alpha", "noise_scale_dur", "speaker", - "alpha", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index 757e67fc1..d85c0a27b 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -101,9 +101,9 @@ class OnnxModel: self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: speaker.numpy(), - self.model.get_inputs()[5].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), + self.model.get_inputs()[5].name: speaker.numpy(), }, )[0] return torch.from_numpy(out) From e9ec827de76856e38af7a884b878ca3a84f64bb9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Dec 2023 14:29:24 +0800 Subject: [PATCH 111/113] Rename zipformer2 to zipformer_for_ncnn_export_only to avoid confusion. 
(#1407) --- .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../{zipformer2.py => zipformer_for_ncnn_export_only.py} | 0 .../pruned_transducer_stateless7_streaming_multi/zipformer2.py | 1 - 12 files changed, 7 insertions(+), 8 deletions(-) delete mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py delete mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py delete mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{zipformer2.py => zipformer_for_ncnn_export_only.py} (100%) delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 3c13c19c6..0fba3b58f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -66,7 +66,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 61a3f27db..0426bc9a3 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index acde72d80..685f6ece6 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -70,7 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index cd26db6f3..9a6d2155b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -69,7 +69,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import 
DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py similarity index 100% rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py deleted file mode 120000 index d3625f478..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From df56aff31ea9b95aa3d9672398a0771dcb8eacc5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Dec 2023 21:11:31 +0800 Subject: [PATCH 112/113] minor fixes to the vits onnx exportation scripts (#1408) --- egs/ljspeech/TTS/vits/export-onnx.py | 2 +- egs/vctk/TTS/vits/export-onnx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index bca6aec99..36a9de27f 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -115,8 +115,8 @@ class OnnxModel(nn.Module): tokens: torch.Tensor, tokens_lens: torch.Tensor, noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, alpha: float = 1.0, + noise_scale_dur: float = 0.8, ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of VITS.inference_batch diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index cfc74fd0a..667ac284b 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -121,9 +121,9 @@ class OnnxModel(nn.Module): tokens: torch.Tensor, tokens_lens: torch.Tensor, noise_scale: float = 0.667, + alpha: float = 1.0, noise_scale_dur: float = 0.8, speaker: int = 20, - alpha: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of VITS.inference_batch From b0f70c9d042da734a5df988b98412e5def6b8072 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 10 Dec 2023 11:38:39 +0800 Subject: [PATCH 113/113] Fix torch.jit.script() export for pruned_transducer_stateless2 (#1410) --- egs/librispeech/ASR/pruned_transducer_stateless2/export.py | 2 ++ egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../ASR/pruned_transducer_stateless2/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index e02afa892..e2db98f73 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -49,6 +49,7 @@ from pathlib import Path import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, 
get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint @@ -198,6 +199,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file
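
For readers of the last patch above: the fix for the ``torch.jit.script()`` export is to convert the ``Scaled*`` layers with ``convert_scaled_to_non_scaled()`` before scripting. The snippet below is only a minimal sketch of that flow, not code from the repository: the helper name ``export_jit``, the default file name ``cpu_jit.pt``, and the assumption that a trained transducer model has already been built and loaded are all illustrative::

    # Minimal sketch (not taken from the patch): convert the Scaled* layers
    # first, then compile with TorchScript. `scaling_converter` is the module
    # that the patch symlinks into pruned_transducer_stateless2.
    import torch

    from scaling_converter import convert_scaled_to_non_scaled


    def export_jit(model: torch.nn.Module, filename: str = "cpu_jit.pt") -> None:
        """Hypothetical helper showing the order of operations."""
        model.eval()
        # Without this conversion, torch.jit.script() fails on the model,
        # which is the problem the patch title refers to.
        convert_scaled_to_non_scaled(model, inplace=True)
        scripted = torch.jit.script(model)
        scripted.save(filename)

The same ordering (convert first, then script) is what the ``if params.jit:`` branch of ``export.py`` now performs after the patch.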
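
Relatedly, the VITS ONNX patches further above (#1405 and #1408) change the exported model's input order to ``tokens``, ``tokens_lens``, ``noise_scale``, ``alpha``, ``noise_scale_dur`` (plus ``speaker`` for the multi-speaker VCTK export). A hedged sketch of checking and driving such an export with onnxruntime follows; the file name, the dummy token ids, and the tensor shapes/dtypes are assumptions rather than values taken from the patches, and should be verified against ``sess.get_inputs()`` for a real export::

    # Sketch only: inspect an exported multi-speaker VITS model and run it,
    # feeding inputs by name in the post-patch order. "vits-vctk.onnx" and the
    # dummy values are placeholders; shapes/dtypes are assumptions.
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("vits-vctk.onnx")
    print([i.name for i in sess.get_inputs()])
    # Expected after the patches (VCTK, multi-speaker):
    # ['tokens', 'tokens_lens', 'noise_scale', 'alpha', 'noise_scale_dur', 'speaker']

    tokens = np.array([[1, 2, 3, 4]], dtype=np.int64)  # dummy token ids
    feeds = {
        "tokens": tokens,
        "tokens_lens": np.array([tokens.shape[1]], dtype=np.int64),
        "noise_scale": np.array([0.667], dtype=np.float32),
        "alpha": np.array([1.0], dtype=np.float32),
        "noise_scale_dur": np.array([0.8], dtype=np.float32),
        "speaker": np.array([20], dtype=np.int64),
    }
    audio = sess.run(["audio"], feeds)[0]  # "audio" is the output name in the export script

The scalar defaults (0.667, 1.0, 0.8, speaker id 20) mirror the defaults visible in the patched ``OnnxModel.forward`` signatures.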