diff --git a/docs/source/recipes/aishell.rst b/docs/source/recipes/aishell.rst
new file mode 100644
index 000000000..71ccaa1fc
--- /dev/null
+++ b/docs/source/recipes/aishell.rst
@@ -0,0 +1,10 @@
+Aishell
+=======
+
+We provide the following models for the Aishell dataset:
+
+.. toctree::
+   :maxdepth: 2
+
+   aishell/conformer_ctc
+   aishell/tdnn_lstm_ctc
diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/aishell/conformer_ctc.rst
new file mode 100644
index 000000000..96ede5b68
--- /dev/null
+++ b/docs/source/recipes/aishell/conformer_ctc.rst
@@ -0,0 +1,573 @@
+Conformer CTC
+=============
+
+This tutorial shows you how to run a conformer CTC model
+with the `Aishell <https://www.openslr.org/33>`_ dataset.
+
+
+.. HINT::
+
+  We assume you have read the page :ref:`install icefall` and have set up
+  the environment for ``icefall``.
+
+.. HINT::
+
+  We recommend using a GPU, or several GPUs, to run this recipe.
+
+In this tutorial, you will learn:
+
+  - (1) How to prepare data for training and decoding
+  - (2) How to start the training, either with a single GPU or multiple GPUs
+  - (3) How to do decoding after training, with 1best and attention decoder rescoring
+  - (4) How to use a pre-trained model, provided by us
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+  - ``--stage``
+  - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+  $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. HINT::
+
+  If you have pre-downloaded the `Aishell <https://www.openslr.org/33>`_
+  dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
+  they are saved in ``/tmp/aishell`` and ``/tmp/musan``, you can modify
+  the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+  ``./prepare.sh`` won't re-download them.
+
+.. HINT::
+
+  A 3-gram language model will be downloaded from huggingface; we assume you
+  have installed and initialized ``git-lfs``. If not, you can install
+  ``git-lfs`` by
+
+  .. code-block:: bash
+
+     $ sudo apt-get install git-lfs
+     $ git-lfs install
+
+  If you don't have ``sudo`` permission, you can download the
+  `git-lfs binary <https://github.com/git-lfs/git-lfs/releases>`_ and add it
+  to your ``PATH``.
+
+.. NOTE::
+
+  All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
+  are saved in the ``./data`` directory.
+
+
+Training
+--------
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+
+  - ``--num-epochs``
+
+    It is the number of epochs to train. For instance,
+    ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
+    and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
+    in the folder ``./conformer_ctc/exp``.
+
+  - ``--start-epoch``
+
+    It's used to resume training.
+    ``./conformer_ctc/train.py --start-epoch 10`` loads the
+    checkpoint ``./conformer_ctc/exp/epoch-9.pt`` and starts
+    training from epoch 10, based on the state from epoch 9.
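+
+    For example, suppose you launched a 30-epoch run and it was interrupted
+    after ``epoch-14.pt`` was saved (the epoch number here is only an
+    illustration). You can resume it like this:
+
+    .. code-block:: bash
+
+      $ ./conformer_ctc/train.py --num-epochs 30
+      # ... interrupted after epoch 14 finishes ...
+      $ ./conformer_ctc/train.py --num-epochs 30 --start-epoch 15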
+
+  - ``--world-size``
+
+    It is used for multi-GPU single-machine DDP training.
+
+      - (a) If it is 1, then no DDP training is used.
+
+      - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+    The following shows some use cases with it.
+
+      **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+      GPU 2 for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/aishell/ASR
+          $ export CUDA_VISIBLE_DEVICES="0,2"
+          $ ./conformer_ctc/train.py --world-size 2
+
+      **Use case 2**: You have 4 GPUs and you want to use all of them
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/aishell/ASR
+          $ ./conformer_ctc/train.py --world-size 4
+
+      **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/aishell/ASR
+          $ export CUDA_VISIBLE_DEVICES="3"
+          $ ./conformer_ctc/train.py --world-size 1
+
+    .. CAUTION::
+
+      Only multi-GPU single-machine DDP training is implemented at present.
+      Multi-GPU multi-machine DDP training will be added later.
+
+  - ``--max-duration``
+
+    It specifies the total number of seconds of all utterances in a
+    batch, before **padding**.
+    If you encounter CUDA OOM, please reduce it. For instance, if
+    you are using a V100 NVIDIA GPU, we recommend setting it to ``200``.
+
+    .. HINT::
+
+      Due to padding, the total number of seconds of all utterances in a
+      batch will usually be larger than ``--max-duration``.
+
+      A larger value for ``--max-duration`` may cause OOM during training,
+      while a smaller value may increase the training time. You have to
+      tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc.,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`conformer_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/conformer_ctc/train.py>`_.
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./conformer_ctc/train.py`` directly.
+
+
+.. CAUTION::
+
+  The training set is perturbed by speed with two factors: 0.9 and 1.1.
+  Each epoch actually processes ``3x150 == 450`` hours of data.
+
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``conformer_ctc/exp``.
+You will find the following files in that directory:
+
+  - ``epoch-0.pt``, ``epoch-1.pt``, ...
+
+    These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
+    To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+    .. code-block:: bash
+
+      $ ./conformer_ctc/train.py --start-epoch 11
+
+  - ``tensorboard/``
+
+    This folder contains TensorBoard logs. Training loss, validation loss,
+    learning rate, etc., are recorded in these logs. You can visualize them by:
+
+    .. code-block:: bash
+
+      $ cd conformer_ctc/exp/tensorboard
+      $ tensorboard dev upload --logdir . --description "Conformer CTC training for Aishell with icefall"
+
+    It will print something like below:
+
+    .. code-block::
+
+      TensorFlow installation not found - running with reduced feature set.
+      Upload started and will continue reading any new data as it's added to the logdir.
+
+      To stop uploading, press Ctrl-C.
+
+      New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/qvNrx6JIQAaN5Ly3uQotrg/
+
+      [2021-09-12T16:41:16] Started scanning logdir.
+      [2021-09-12T16:42:17] Total uploaded: 125346 scalars, 0 tensors, 0 binary objects
+      Listening for new data in logdir...
+
+    Note there is a URL in the above output. Click it and you will see
+    the following screenshot:
+
+      .. figure:: images/aishell-conformer-ctc-tensorboard-log.jpg
+         :width: 600
+         :alt: TensorBoard screenshot
+         :align: center
+         :target: https://tensorboard.dev/experiment/qvNrx6JIQAaN5Ly3uQotrg/
+
+         TensorBoard screenshot.
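+
+    If you prefer to inspect the logs locally instead of uploading them,
+    you can run TensorBoard directly (a sketch; the port number is just an
+    example):
+
+    .. code-block:: bash
+
+      $ tensorboard --logdir conformer_ctc/exp/tensorboard --port 6006
+      # then open http://localhost:6006 in your browser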
+
+  - ``log/log-train-xxxx``
+
+    It is the detailed training log in text format, the same as the one
+    you saw printed to the console during training.
+
+Usage examples
+~~~~~~~~~~~~~~
+
+The following shows typical use cases:
+
+**Case 1**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/train.py --max-duration 200
+
+It uses ``--max-duration`` of 200 to avoid OOM.
+
+
+**Case 2**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ export CUDA_VISIBLE_DEVICES="0,3"
+  $ ./conformer_ctc/train.py --world-size 2
+
+It uses GPU 0 and GPU 3 for DDP training.
+
+**Case 3**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/train.py --num-epochs 10 --start-epoch 3
+
+It loads checkpoint ``./conformer_ctc/exp/epoch-2.pt`` and starts
+training from epoch 3. Also, it trains for 10 epochs.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/decode.py --help
+
+shows the options for decoding.
+
+The commonly used options are:
+
+  - ``--method``
+
+    This specifies the decoding method.
+
+    The following command uses the attention decoder for rescoring:
+
+    .. code-block:: bash
+
+      $ cd egs/aishell/ASR
+      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+
+  - ``--lattice-score-scale``
+
+    It is used to scale down lattice scores so that there are more unique
+    paths for rescoring.
+
+  - ``--max-duration``
+
+    It has the same meaning as the one used during training. A larger
+    value may cause OOM.
+
+Pre-trained Model
+-----------------
+
+We have uploaded a pre-trained model to
+`<https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc>`_.
+
+In the following, we describe how to use the pre-trained model to
+transcribe a sound file or multiple sound files.
+
+Install kaldifeat
+~~~~~~~~~~~~~~~~~
+
+`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
+extract features for a single sound file or multiple sound files
+at the same time.
+
+Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The following commands describe how to download the pre-trained model:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ mkdir tmp
+  $ cd tmp
+  $ git lfs install
+  $ git clone https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc
+
+.. CAUTION::
+
+  You have to use ``git lfs`` to download the pre-trained model.
+
+.. CAUTION::
+
+  In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ tree tmp
+
+.. code-block:: bash
+
+  tmp/
+  `-- icefall_asr_aishell_conformer_ctc
+      |-- README.md
+      |-- data
+      |   `-- lang_char
+      |       |-- HLG.pt
+      |       |-- tokens.txt
+      |       `-- words.txt
+      |-- exp
+      |   `-- pretrained.pt
+      `-- test_waves
+          |-- BAC009S0764W0121.wav
+          |-- BAC009S0764W0122.wav
+          |-- BAC009S0764W0123.wav
+          `-- trans.txt
+
+  5 directories, 9 files
+
+**File descriptions**:
+
+  - ``data/lang_char/HLG.pt``
+
+    It is the decoding graph.
+
+  - ``data/lang_char/tokens.txt``
+
+    It contains tokens and their IDs.
+    It is provided only for convenience so that you can look up the SOS/EOS ID easily.
+
+  - ``data/lang_char/words.txt``
+
+    It contains words and their IDs.
+
+  - ``exp/pretrained.pt``
+
+    It contains pre-trained model parameters, obtained by averaging
+    checkpoints from ``epoch-18.pt`` to ``epoch-40.pt``.
+    Note: We have removed the optimizer ``state_dict`` to reduce file size.
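+
+    If you want to produce such an averaged checkpoint from your own training
+    run, here is a minimal sketch using the same helper that ``decode.py``
+    uses internally (the epoch range mirrors the description above):
+
+    .. code-block:: python
+
+      import torch
+      from icefall.checkpoint import average_checkpoints
+
+      # Average the model weights of epoch-18.pt, ..., epoch-40.pt.
+      filenames = [f"conformer_ctc/exp/epoch-{i}.pt" for i in range(18, 41)]
+      avg = average_checkpoints(filenames)
+
+      # Keep only the averaged model state_dict, as in exp/pretrained.pt.
+      torch.save({"model": avg}, "conformer_ctc/exp/pretrained.pt")
+
+    Alternatively, passing ``--export True`` to ``./conformer_ctc/decode.py``
+    performs this export for you.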
+
+  - ``test_waves/*.wav``
+
+    It contains some test sound files from the Aishell ``test`` dataset.
+
+  - ``test_waves/trans.txt``
+
+    It contains the reference transcripts for the sound files in ``test_waves/``.
+
+The information about the test sound files is listed below:
+
+.. code-block:: bash
+
+  $ soxi tmp/icefall_asr_aishell_conformer_ctc/test_waves/*.wav
+
+  Input File     : 'tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav'
+  Channels       : 1
+  Sample Rate    : 16000
+  Precision      : 16-bit
+  Duration       : 00:00:04.20 = 67263 samples ~ 315.295 CDDA sectors
+  File Size      : 135k
+  Bit Rate       : 256k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+
+  Input File     : 'tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav'
+  Channels       : 1
+  Sample Rate    : 16000
+  Precision      : 16-bit
+  Duration       : 00:00:04.12 = 65840 samples ~ 308.625 CDDA sectors
+  File Size      : 132k
+  Bit Rate       : 256k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+
+  Input File     : 'tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav'
+  Channels       : 1
+  Sample Rate    : 16000
+  Precision      : 16-bit
+  Duration       : 00:00:04.00 = 64000 samples ~ 300 CDDA sectors
+  File Size      : 128k
+  Bit Rate       : 256k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+  Total Duration of 3 files: 00:00:12.32
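+
+If you want to perform the same check programmatically instead of with ``soxi``,
+here is a minimal sketch using ``torchaudio`` (the file path is just an example):
+
+.. code-block:: python
+
+  import torchaudio
+
+  # torchaudio.load() returns (waveform, sample_rate).
+  wave, sample_rate = torchaudio.load(
+      "tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav"
+  )
+  print(wave.shape)   # (num_channels, num_samples), e.g., (1, 67263)
+  print(sample_rate)  # must be 16000 for pretrained.py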
+
+Usage
+~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/pretrained.py --help
+
+displays the help information.
+
+It supports two decoding methods:
+
+  - HLG decoding
+  - HLG + attention decoder rescoring
+
+HLG decoding
+^^^^^^^^^^^^
+
+HLG decoding uses the best path of the decoding lattice as the decoding result.
+
+The command to run HLG decoding is:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_aishell_conformer_ctc/exp/pretrained.pt \
+    --words-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
+    --HLG ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
+    --method 1best \
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
+
+The output is given below:
+
+.. code-block::
+
+  2021-09-13 10:46:59,842 INFO [pretrained.py:219] device: cuda:0
+  2021-09-13 10:46:59,842 INFO [pretrained.py:221] Creating model
+  2021-09-13 10:47:54,682 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt
+  2021-09-13 10:48:46,111 INFO [pretrained.py:245] Constructing Fbank computer
+  2021-09-13 10:48:46,113 INFO [pretrained.py:255] Reading sound files: ['./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav', './tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav', './tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav']
+  2021-09-13 10:48:46,368 INFO [pretrained.py:262] Decoding started
+  2021-09-13 10:48:46,847 INFO [pretrained.py:291] Use HLG decoding
+  2021-09-13 10:48:47,176 INFO [pretrained.py:322]
+  ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav:
+  甚至 出现 交易 几乎 停止 的 情况
+
+  ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav:
+  一二 线 城市 虽然 也 处于 调整 中
+
+  ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav:
+  但 因为 聚集 了 过多 公共 资源
+
+
+  2021-09-13 10:48:47,177 INFO [pretrained.py:324] Decoding Done
+
+HLG decoding + attention decoder rescoring
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It extracts n paths from the lattice and rescores the extracted paths with
+an attention decoder. The path with the highest score is the decoding result.
+
+The command to run HLG decoding + attention decoder rescoring is:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./conformer_ctc/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_aishell_conformer_ctc/exp/pretrained.pt \
+    --words-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
+    --HLG ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
+    --method attention-decoder \
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
+
+The output is below:
+
+.. code-block::
+
+  2021-09-13 11:02:15,852 INFO [pretrained.py:219] device: cuda:0
+  2021-09-13 11:02:15,852 INFO [pretrained.py:221] Creating model
+  2021-09-13 11:02:22,292 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt
+  2021-09-13 11:02:27,060 INFO [pretrained.py:245] Constructing Fbank computer
+  2021-09-13 11:02:27,062 INFO [pretrained.py:255] Reading sound files: ['./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav', './tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav', './tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav']
+  2021-09-13 11:02:27,129 INFO [pretrained.py:261] Decoding started
+  2021-09-13 11:02:27,241 INFO [pretrained.py:295] Use HLG + attention decoder rescoring
+  2021-09-13 11:02:27,823 INFO [pretrained.py:318]
+  ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav:
+  甚至 出现 交易 几乎 停止 的 情况
+
+  ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav:
+  一二 线 城市 虽然 也 处于 调整 中
+
+  ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav:
+  但 因为 聚集 了 过多 公共 资源
+
+
+  2021-09-13 11:02:27,823 INFO [pretrained.py:320] Decoding Done
+
+Colab notebook
+--------------
+
+We provide a colab notebook for this recipe, showing how to use a pre-trained model.
+
+|aishell asr conformer ctc colab notebook|
+
+.. |aishell asr conformer ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+   :target: https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC
+
+.. HINT::
+
+  Due to the limited memory provided by Colab, you have to upgrade to Colab Pro to
+  run ``HLG decoding + attention decoder rescoring``.
+  Otherwise, you can only run ``HLG decoding`` with Colab.
+
+**Congratulations!** You have finished the Aishell ASR recipe with
+conformer CTC models in ``icefall``.
diff --git a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
new file mode 100644
index 000000000..47f7d18a7
Binary files /dev/null and b/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg differ
diff --git a/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg b/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
new file mode 100644
index 000000000..b31db3ab5
Binary files /dev/null and b/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg differ
diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst
new file mode 100644
index 000000000..e9b0ea656
--- /dev/null
+++ b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst
@@ -0,0 +1,504 @@
+TDNN-LSTM CTC
+=============
+
+This tutorial shows you how to run a TDNN-LSTM CTC model
+with the `Aishell <https://www.openslr.org/33>`_ dataset.
+
+
+.. HINT::
+
+  We assume you have read the page :ref:`install icefall` and have set up
+  the environment for ``icefall``.
+
+.. HINT::
+
+  We recommend using a GPU, or several GPUs, to run this recipe.
+
+In this tutorial, you will learn:
+
+  - (1) How to prepare data for training and decoding
+  - (2) How to start the training, either with a single GPU or multiple GPUs
+  - (3) How to do decoding after training
+  - (4) How to use a pre-trained model, provided by us
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+  - ``--stage``
+  - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+  $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. HINT::
+
+  If you have pre-downloaded the `Aishell <https://www.openslr.org/33>`_
+  dataset and the `musan <http://www.openslr.org/17/>`_ dataset, say,
+  they are saved in ``/tmp/aishell`` and ``/tmp/musan``, you can modify
+  the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+  ``./prepare.sh`` won't re-download them.
+
+.. HINT::
+
+  A 3-gram language model will be downloaded from huggingface; we assume you
+  have installed and initialized ``git-lfs``. If not, you can install
+  ``git-lfs`` by
+
+  .. code-block:: bash
+
+     $ sudo apt-get install git-lfs
+     $ git-lfs install
+
+  If you don't have ``sudo`` permission, you can download the
+  `git-lfs binary <https://github.com/git-lfs/git-lfs/releases>`_ and add it
+  to your ``PATH``.
+
+.. NOTE::
+
+  All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
+  are saved in the ``./data`` directory.
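+
+.. HINT::
+
+  ``./prepare.sh`` stores features and supervisions as `lhotse <https://github.com/lhotse-speech/lhotse>`_
+  manifests, which you can inspect from Python. A minimal sketch (the manifest
+  filename below is an assumption; check the actual names under ``./data``):
+
+  .. code-block:: python
+
+    from lhotse import load_manifest
+
+    cuts = load_manifest("data/fbank/cuts_train.json.gz")
+    print(len(cuts))         # number of cuts (utterances)
+    print(next(iter(cuts)))  # one cut: duration, features, transcript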
+
+
+Training
+--------
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./tdnn_lstm_ctc/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+
+  - ``--num-epochs``
+
+    It is the number of epochs to train. For instance,
+    ``./tdnn_lstm_ctc/train.py --num-epochs 30`` trains for 30 epochs
+    and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
+    in the folder ``./tdnn_lstm_ctc/exp``.
+
+  - ``--start-epoch``
+
+    It's used to resume training.
+    ``./tdnn_lstm_ctc/train.py --start-epoch 10`` loads the
+    checkpoint ``./tdnn_lstm_ctc/exp/epoch-9.pt`` and starts
+    training from epoch 10, based on the state from epoch 9.
+
+  - ``--world-size``
+
+    It is used for multi-GPU single-machine DDP training.
+
+      - (a) If it is 1, then no DDP training is used.
+
+      - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+    The following shows some use cases with it.
+
+      **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+      GPU 2 for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/aishell/ASR
+          $ export CUDA_VISIBLE_DEVICES="0,2"
+          $ ./tdnn_lstm_ctc/train.py --world-size 2
+
+      **Use case 2**: You have 4 GPUs and you want to use all of them
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/aishell/ASR
+          $ ./tdnn_lstm_ctc/train.py --world-size 4
+
+      **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/aishell/ASR
+          $ export CUDA_VISIBLE_DEVICES="3"
+          $ ./tdnn_lstm_ctc/train.py --world-size 1
+
+    .. CAUTION::
+
+      Only multi-GPU single-machine DDP training is implemented at present.
+      Multi-GPU multi-machine DDP training will be added later.
+
+  - ``--max-duration``
+
+    It specifies the total number of seconds of all utterances in a
+    batch, before **padding**.
+    If you encounter CUDA OOM, please reduce it. For instance, if
+    you are using a V100 NVIDIA GPU, we recommend setting it to ``2000``.
+
+    .. HINT::
+
+      Due to padding, the total number of seconds of all utterances in a
+      batch will usually be larger than ``--max-duration``.
+
+      A larger value for ``--max-duration`` may cause OOM during training,
+      while a smaller value may increase the training time. You have to
+      tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc.,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`tdnn_lstm_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/tdnn_lstm_ctc/train.py>`_.
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./tdnn_lstm_ctc/train.py`` directly.
+
+
+.. CAUTION::
+
+  The training set is perturbed by speed with two factors: 0.9 and 1.1.
+  Each epoch actually processes ``3x150 == 450`` hours of data.
+
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``tdnn_lstm_ctc/exp``.
+You will find the following files in that directory:
+
+  - ``epoch-0.pt``, ``epoch-1.pt``, ...
+
+    These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
+    To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+    .. code-block:: bash
+
+      $ ./tdnn_lstm_ctc/train.py --start-epoch 11
+
+  - ``tensorboard/``
+
+    This folder contains TensorBoard logs.
+    Training loss, validation loss, learning rate, etc., are recorded
+    in these logs. You can visualize them by:
+
+    .. code-block:: bash
+
+      $ cd tdnn_lstm_ctc/exp/tensorboard
+      $ tensorboard dev upload --logdir . --description "TDNN-LSTM CTC training for Aishell with icefall"
+
+    It will print something like below:
+
+    .. code-block::
+
+      TensorFlow installation not found - running with reduced feature set.
+      Upload started and will continue reading any new data as it's added to the logdir.
+
+      To stop uploading, press Ctrl-C.
+
+      New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/LJI9MWUORLOw3jkdhxwk8A/
+
+      [2021-09-13T11:59:23] Started scanning logdir.
+      [2021-09-13T11:59:24] Total uploaded: 4454 scalars, 0 tensors, 0 binary objects
+      Listening for new data in logdir...
+
+    Note there is a URL in the above output. Click it and you will see
+    the following screenshot:
+
+      .. figure:: images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
+         :width: 600
+         :alt: TensorBoard screenshot
+         :align: center
+         :target: https://tensorboard.dev/experiment/LJI9MWUORLOw3jkdhxwk8A/
+
+         TensorBoard screenshot.
+
+  - ``log/log-train-xxxx``
+
+    It is the detailed training log in text format, the same as the one
+    you saw printed to the console during training.
+
+Usage examples
+~~~~~~~~~~~~~~
+
+The following shows typical use cases:
+
+**Case 1**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ export CUDA_VISIBLE_DEVICES="0,3"
+  $ ./tdnn_lstm_ctc/train.py --world-size 2
+
+It uses GPU 0 and GPU 3 for DDP training.
+
+**Case 2**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./tdnn_lstm_ctc/train.py --num-epochs 10 --start-epoch 3
+
+It loads checkpoint ``./tdnn_lstm_ctc/exp/epoch-2.pt`` and starts
+training from epoch 3. Also, it trains for 10 epochs.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./tdnn_lstm_ctc/decode.py --help
+
+shows the options for decoding.
+
+The commonly used options are:
+
+  - ``--method``
+
+    This specifies the decoding method.
+
+    The following command uses the ``1best`` decoding method:
+
+    .. code-block:: bash
+
+      $ cd egs/aishell/ASR
+      $ ./tdnn_lstm_ctc/decode.py --method 1best --max-duration 100
+
+  - ``--max-duration``
+
+    It has the same meaning as the one used during training. A larger
+    value may cause OOM.
+
+Pre-trained Model
+-----------------
+
+We have uploaded a pre-trained model to
+`<https://huggingface.co/pkufool/icefall_asr_aishell_tdnn_lstm_ctc>`_.
+
+In the following, we describe how to use the pre-trained model to
+transcribe a sound file or multiple sound files.
+
+Install kaldifeat
+~~~~~~~~~~~~~~~~~
+
+`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
+extract features for a single sound file or multiple sound files
+at the same time.
+
+Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The following commands describe how to download the pre-trained model:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ mkdir tmp
+  $ cd tmp
+  $ git lfs install
+  $ git clone https://huggingface.co/pkufool/icefall_asr_aishell_tdnn_lstm_ctc
+
+.. CAUTION::
+
+  You have to use ``git lfs`` to download the pre-trained model.
+
+.. CAUTION::
+
+  In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ tree tmp
+
+.. code-block:: bash
+
+  tmp/
+  `-- icefall_asr_aishell_tdnn_lstm_ctc
+      |-- README.md
+      |-- data
+      |   `-- lang_phone
+      |       |-- HLG.pt
+      |       |-- tokens.txt
+      |       `-- words.txt
+      |-- exp
+      |   `-- pretrained.pt
+      `-- test_waves
+          |-- BAC009S0764W0121.wav
+          |-- BAC009S0764W0122.wav
+          |-- BAC009S0764W0123.wav
+          `-- trans.txt
+
+  5 directories, 9 files
+
+**File descriptions**:
+
+  - ``data/lang_phone/HLG.pt``
+
+    It is the decoding graph.
+
+  - ``data/lang_phone/tokens.txt``
+
+    It contains tokens (phones) and their IDs.
+
+  - ``data/lang_phone/words.txt``
+
+    It contains words and their IDs.
+
+  - ``exp/pretrained.pt``
+
+    It contains pre-trained model parameters, obtained by averaging
+    checkpoints from ``epoch-18.pt`` to ``epoch-40.pt``.
+    Note: We have removed the optimizer ``state_dict`` to reduce file size.
+
+  - ``test_waves/*.wav``
+
+    It contains some test sound files from the Aishell ``test`` dataset.
+
+  - ``test_waves/trans.txt``
+
+    It contains the reference transcripts for the sound files in ``test_waves/``.
+
+The information about the test sound files is listed below:
+
+.. code-block:: bash
+
+  $ soxi tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/*.wav
+
+  Input File     : 'tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav'
+  Channels       : 1
+  Sample Rate    : 16000
+  Precision      : 16-bit
+  Duration       : 00:00:04.20 = 67263 samples ~ 315.295 CDDA sectors
+  File Size      : 135k
+  Bit Rate       : 256k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+
+  Input File     : 'tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0122.wav'
+  Channels       : 1
+  Sample Rate    : 16000
+  Precision      : 16-bit
+  Duration       : 00:00:04.12 = 65840 samples ~ 308.625 CDDA sectors
+  File Size      : 132k
+  Bit Rate       : 256k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+
+  Input File     : 'tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0123.wav'
+  Channels       : 1
+  Sample Rate    : 16000
+  Precision      : 16-bit
+  Duration       : 00:00:04.00 = 64000 samples ~ 300 CDDA sectors
+  File Size      : 128k
+  Bit Rate       : 256k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+  Total Duration of 3 files: 00:00:12.32
+
+Usage
+~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./tdnn_lstm_ctc/pretrained.py --help
+
+displays the help information.
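+
+.. HINT::
+
+  ``./tdnn_lstm_ctc/pretrained.py`` expects 16 kHz sound files. If you want to
+  try your own recordings and they have a different sample rate or more than
+  one channel, you can convert them first, e.g., with ``sox`` (a sketch; the
+  file names are just examples):
+
+  .. code-block:: bash
+
+    $ sox your_recording.wav -r 16000 -c 1 your_recording_16k.wav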
+
+
+HLG decoding
+^^^^^^^^^^^^
+
+HLG decoding uses the best path of the decoding lattice as the decoding result.
+
+The command to run HLG decoding is:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ./tdnn_lstm_ctc/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/exp/pretrained.pt \
+    --words-file ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/words.txt \
+    --HLG ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
+    --method 1best \
+    ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav \
+    ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0122.wav \
+    ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0123.wav
+
+The output is given below:
+
+.. code-block::
+
+  2021-09-13 15:00:55,858 INFO [pretrained.py:140] device: cuda:0
+  2021-09-13 15:00:55,858 INFO [pretrained.py:142] Creating model
+  2021-09-13 15:01:05,389 INFO [pretrained.py:154] Loading HLG from ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/HLG.pt
+  2021-09-13 15:01:06,531 INFO [pretrained.py:161] Constructing Fbank computer
+  2021-09-13 15:01:06,536 INFO [pretrained.py:171] Reading sound files: ['./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav', './tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0122.wav', './tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0123.wav']
+  2021-09-13 15:01:06,539 INFO [pretrained.py:177] Decoding started
+  2021-09-13 15:01:06,917 INFO [pretrained.py:207] Use HLG decoding
+  2021-09-13 15:01:07,129 INFO [pretrained.py:220]
+  ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav:
+  甚至 出现 交易 几乎 停滞 的 情况
+
+  ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0122.wav:
+  一二 线 城市 虽然 也 处于 调整 中
+
+  ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0123.wav:
+  但 因为 聚集 了 过多 公共 资源
+
+
+  2021-09-13 15:01:07,129 INFO [pretrained.py:222] Decoding Done
+
+
+Colab notebook
+--------------
+
+We provide a colab notebook for this recipe, showing how to use a pre-trained model.
+
+|aishell asr tdnn lstm ctc colab notebook|
+
+.. |aishell asr tdnn lstm ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+   :target: https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z
+
+**Congratulations!** You have finished the Aishell ASR recipe with
+TDNN-LSTM CTC models in ``icefall``.
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 36f8dfc39..ab81e5875 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -15,3 +15,5 @@ We may add recipes for other tasks as well in the future.
 
    yesno
    librispeech
+
+   aishell
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 41220900e..5cbe5d213 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,3 +1,56 @@
 ## Results
 
-Adding soon...
+### Aishell training results (Conformer-CTC)
+#### 2021-09-13
+(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
+
+Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc
+
+The best decoding results (CER) are listed below. We obtained these results by averaging models from epoch 23 to 40 and using the `attention-decoder` method with `num_paths` set to 100.
+
+||test|
+|--|--|
+|CER| 4.74% |
+
+To get more unique paths, we scaled `lattice.scores` by 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details). We searched over `lm_score_scale` and `attention_score_scale` for the best results; the scales that produced the CER above are listed below.
+
+||lm_scale|attention_scale|
+|--|--|--|
+|test|0.3|0.9|
+
+You can use the following commands to reproduce our results:
+
+```bash
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+
+cd egs/aishell/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1"
+python conformer_ctc/train.py --bucketing-sampler False \
+                              --concatenate-cuts False \
+                              --max-duration 200 \
+                              --world-size 2
+
+python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+                               --epoch 40 \
+                               --avg 18 \
+                               --method attention-decoder \
+                               --max-duration 50 \
+                               --num-paths 100
+```
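+
+If you also want to export the averaged model as `pretrained.pt` (only the model
+`state_dict` is kept; the optimizer state is dropped), you can pass the `--export`
+flag added in this PR:
+
+```bash
+python conformer_ctc/decode.py --epoch 40 --avg 18 --export True
+```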
+
+### Aishell training results (TDNN-LSTM)
+#### 2021-09-13
+
+(Wei Kang): Result of phone-based TDNN-LSTM model, https://github.com/k2-fsa/icefall/pull/30
+
+Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_tdnn_lstm_ctc
+
+The best decoding results (CER) are listed below. We obtained these results by averaging models from epoch 19 to 8 and using the `1best` decoding method.
+
+||test|
+|--|--|
+|CER| 10.16% |
+
diff --git a/egs/aishell/ASR/conformer_ctc/README.md b/egs/aishell/ASR/conformer_ctc/README.md
new file mode 100644
index 000000000..50596ee92
--- /dev/null
+++ b/egs/aishell/ASR/conformer_ctc/README.md
@@ -0,0 +1,4 @@
+
+Please visit
+<https://icefall.readthedocs.io/en/latest/recipes/aishell/conformer_ctc.html>
+for how to run this recipe.
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index b4b113d4b..20a8f7b3a 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -45,6 +45,7 @@ from icefall.utils import (
     get_texts,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
 
@@ -110,6 +111,17 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
+
     return parser
 
 
@@ -364,12 +376,13 @@ def save_results(
     # The following prints out WERs, per-word error statistics and aligned
     # ref/hyp pairs.
     errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
-    results_tmp = []
+    # We compute CER for the Aishell dataset.
+ results_char = [] for res in results: - results_tmp.append((list("".join(res[0])), list("".join(res[1])))) + results_char.append((list("".join(res[0])), list("".join(res[1])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_tmp, enable_log=enable_log + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log ) test_set_wers[key] = wer @@ -379,13 +392,13 @@ def save_results( ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" with open(errs_info, "w") as f: - print("settings\tWER", file=f) + print("settings\tCER", file=f) for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: s += "{}\t{}{}\n".format(key, val, note) @@ -457,6 +470,13 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py new file mode 100755 index 000000000..846681f00 --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from conformer import Conformer +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. 
We call it HLG decoding.
+        (2) attention-decoder - Extract n paths from the
+        lattice and use the transformer attention decoder for
+        rescoring.
+        We call it HLG decoding + attention decoder rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of the n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.3,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--attention-decoder-scale",
+        type=float,
+        default=0.9,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for attention decoder scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--lattice-score-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, at the risk of missing the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--sos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the ID of the SOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "--eos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the ID of the EOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "sample_rate": 16000,
+            "num_classes": 4336,
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "feature_dim": 80,
+            "nhead": 4,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + # Note: We don't use key padding mask for attention during decoding + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "attention-decoder": + logging.info("Use HLG + attention decoder rescoring") + best_path_dict = rescore_with_attention_decoder( + lattice=lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + scale=params.lattice_score_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s 
%(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/aishell/ASR/local/compile_hlg.py b/egs/aishell/ASR/local/compile_hlg.py
index 19a1ddd23..407fb7d88 100755
--- a/egs/aishell/ASR/local/compile_hlg.py
+++ b/egs/aishell/ASR/local/compile_hlg.py
@@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 
     LG.labels[LG.labels >= first_token_disambig_id] = 0
 
-    assert isinstance(LG.aux_labels, k2.RaggedInt)
-    LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
+    assert isinstance(LG.aux_labels, k2.RaggedTensor)
+    LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
 
     LG = k2.remove_epsilon(LG)
     logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
 
     LG = k2.connect(LG)
-    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
 
     logging.info("Arc sorting LG")
     LG = k2.arc_sort(LG)
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index c9437b414..568da3811 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -66,6 +66,26 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        """,
+    )
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=30,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is nbest.
+        """,
+    )
     parser.add_argument(
         "--export",
         type=str2bool,
@@ -82,22 +102,18 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("tdnn_lstm_ctc/exp_lr1e-4/"),
+            "exp_dir": Path("tdnn_lstm_ctc/exp/"),
             "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
-            "feature_dim": 80,
+            # parameters for tdnn_lstm_ctc
             "subsampling_factor": 3,
+            "feature_dim": 80,
+            # parameters for decoding
             "search_beam": 20,
-            "output_beam": 5,
+            "output_beam": 7,
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            # - 1best
-            # - nbest
-            "method": "1best",
-            # num_paths is used when method is "nbest"
-            "num_paths": 30,
         }
     )
     return params
@@ -274,23 +290,24 @@ def save_results(
     # The following prints out WERs, per-word error statistics and aligned
     # ref/hyp pairs.
     errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
-    results_tmp = []
+    # We compute CER for the Aishell dataset.
+    results_char = []
     for res in results:
-        results_tmp.append((list("".join(res[0])), list("".join(res[1]))))
+        results_char.append((list("".join(res[0])), list("".join(res[1]))))
     with open(errs_filename, "w") as f:
-        wer = write_error_stats(f, f"{test_set_name}-{key}", results_tmp)
+        wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
         test_set_wers[key] = wer
 
     logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
     with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
+        print("settings\tCER", file=f)
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)
 
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
         s += "{}\t{}{}\n".format(key, val, note)
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
new file mode 100644
index 000000000..8421dd3ea
--- /dev/null
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -0,0 +1,231 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                        Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from model import TdnnLstm
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+)
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        required=True,
+        help="Path to words.txt",
+    )
+
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Currently, only 1best is supported:
+        use the best path of the decoding lattice as the decoding output.
+        We call it HLG decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
" + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 3, + "num_classes": 220, + "sample_rate": 16000, + "search_beam": 20, + "output_beam": 7, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = TdnnLstm( + num_features=params.feature_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + features = features.permute(0, 2, 1) # now features is [N, C, T] + + with torch.no_grad(): + nnet_output = model(features) + # nnet_output is [N, T, C] + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + assert(params.method == "1best") + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in 
zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main()