add timit recipe for icefall

This commit is contained in:
Mingshuang Luo 2021-11-08 21:12:41 +08:00
parent 91cfecebf2
commit 3f1930e1a1
15 changed files with 4095 additions and 1 deletions

View File

@ -10,8 +10,10 @@ We may add recipes for other tasks as well in the future.
.. Other recipes are listed in a alphabetical order. .. Other recipes are listed in a alphabetical order.
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 3
yesno yesno
librispeech librispeech
timit

View File

@ -0,0 +1,10 @@
TIMIT
===========
We provide the following models for the TIMIT dataset:
.. toctree::
:maxdepth: 2
timit/tdnn_lstm_ctc
timit/tdnn_ligru_ctc

View File

@ -0,0 +1,392 @@
TDNN-LiGRU-CTC
=============
This tutorial shows you how to run a TDNN-LiGRU-CTC model with the `TIMIT <https://data.deepai.org/timit.zip>`_ dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
Data preparation
----------------
.. code-block:: bash
$ cd egs/timit/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
The data preparation contains several stages, you can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/timit/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
Training
--------
Now describing the training of TDNN-LiGRU-CTC model, contained in
the `tdnn_ligru_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_ligru_ctc>`_
folder.
.. HINT::
TIMIT is very small dataset. So one GPU is enough.
The command to run the training part is:
.. code-block:: bash
$ cd egs/timit/ASR
$ export CUDA_VISIBLE_DEVICES="0"
$ ./tdnn_ligru_ctc/train.py
By default, it will run ``30`` epochs. Training logs and checkpoints are saved
in ``tdnn_ligru_ctc/exp``.
In ``tdnn_ligru_ctc/exp``, you will find the following files:
- ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./tdnn_ligru_ctc/train.py --start-epoch 11
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd tdnn_ligru_ctc/exp/tensorboard
$ tensorboard dev upload --logdir . --description "TDNN ligru training for timit with icefall"
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
To see available training options, you can use:
.. code-block:: bash
$ ./tdnn_ligru_ctc/train.py --help
Other training options, e.g., learning rate, results dir, etc., are
pre-configured in the function ``get_params()``
in `tdnn_ligru_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/timit/ASR/tdnn_ligru_ctc/train.py>`_.
Normally, you don't need to change them. You can change them by modifying the code, if
you want.
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
The command for decoding is:
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0"
$ ./tdnn_ligru_ctc/decode.py
You will see the WER in the output log.
Decoded results are saved in ``tdnn_ligru_ctc/exp``.
.. code-block:: bash
$ ./tdnn_ligru_ctc/decode.py --help
shows you the available decoding options.
Some commonly used options are:
- ``--epoch``
You can select which checkpoint to be used for decoding.
For instance, ``./tdnn_ligru_ctc/decode.py --epoch 10`` means to use
``./tdnn_ligru_ctc/exp/epoch-10.pt`` for decoding.
- ``--avg``
It's related to model averaging. It specifies number of checkpoints
to be averaged. The averaged model is used for decoding.
For example, the following command:
.. code-block:: bash
$ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
``epoch-24.pt`` and ``epoch-25.pt``
for decoding.
- ``--export``
If it is ``True``, i.e., ``./tdnn_ligru_ctc/decode.py --export 1``, the code
will save the averaged model to ``tdnn_ligru_ctc/exp/pretrained.pt``.
See :ref:`tdnn_ligru_ctc use a pre-trained model` for how to use it.
.. _tdnn_ligru_ctc use a pre-trained model:
Pre-trained Model
-----------------
We have uploaded the pre-trained model to
`<https://huggingface.co/luomingshuang/icefall_asr_timit_tdnn_ligru_ctc>`_.
The following shows you how to use the pre-trained model.
Install kaldifeat
~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
extract features for a single sound file or multiple sound files
at the same time.
Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/timit/ASR
$ mkdir tmp-ligru
$ cd tmp-ligru
$ git lfs install
$ git clone https://huggingface.co/luomingshuang/icefall_asr_timit_tdnn_ligru_ctc
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
.. CAUTION::
In order to use this pre-trained model, your k2 version has to be v1.7 or later.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/timit/ASR
$ tree tmp-ligru
.. code-block:: bash
tmp-ligru/
`-- icefall_asr_timit_tdnn_ligru_ctc
|-- README.md
|-- data
| |-- lang_phone
| | |-- HLG.pt
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretrained.pt
`-- test_wavs
|-- FDHC0_SI1559.WAV
|-- FELC0_SI756.WAV
|-- FMGD0_SI1564.WAV
`-- trans.txt
6 directories, 10 files
**File descriptions**:
- ``data/lang_phone/HLG.pt``
It is the decoding graph.
- ``data/lang_phone/tokens.txt``
It contains tokens and their IDs.
- ``data/lang_phone/words.txt``
It contains words and their IDs.
- ``data/lm/G_4_gram.pt``
It is a 4-gram LM, useful for LM rescoring.
- ``exp/pretrained.pt``
It contains pre-trained model parameters, obtained by averaging
checkpoints from ``epoch-9.pt`` to ``epoch-25.pt``.
Note: We have removed optimizer ``state_dict`` to reduce file size.
- ``test_waves/*.WAV``
It contains some test sound files from timit ``TEST`` dataset.
- ``test_waves/trans.txt``
It contains the reference transcripts for the sound files in ``test_waves/``.
The information of the test sound files is listed below:
.. code-block:: bash
$ ffprobe -show_format tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
Input #0, nistsphere, from 'tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV':
Metadata:
database_id : TIMIT
database_version: 1.0
utterance_id : dhc0_si1559
sample_min : -4176
sample_max : 5984
Duration: 00:00:03.40, bitrate: 258 kb/s
Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s
$ ffprobe -show_format tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
Input #0, nistsphere, from 'tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV':
Metadata:
database_id : TIMIT
database_version: 1.0
utterance_id : elc0_si756
sample_min : -1546
sample_max : 1989
Duration: 00:00:04.19, bitrate: 257 kb/s
Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s
$ ffprobe -show_format tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
Input #0, nistsphere, from 'tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV':
Metadata:
database_id : TIMIT
database_version: 1.0
utterance_id : mgd0_si1564
sample_min : -7626
sample_max : 10573
Duration: 00:00:04.44, bitrate: 257 kb/s
Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/timit/ASR
$ ./tdnn_ligru_ctc/pretrained.py --help
shows the usage information of ``./tdnn_ligru_ctc/pretrained.py``.
To decode with ``1best`` method, we can use:
.. code-block:: bash
./tdnn_ligru_ctc/pretrained.py
--method 1best
--checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_16_25.pt
--words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
--HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
The output is:
.. code-block::
2021-11-08 20:41:33,660 INFO [pretrained.py:169] device: cuda:0
2021-11-08 20:41:33,660 INFO [pretrained.py:171] Creating model
2021-11-08 20:41:38,680 INFO [pretrained.py:183] Loading HLG from ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
2021-11-08 20:41:38,695 INFO [pretrained.py:200] Constructing Fbank computer
2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
2021-11-08 20:41:39,829 INFO [pretrained.py:267]
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV:
sil m ih sil t ih r iy s sil s er r ih m ih sil m aa l ih sil k l ey sil r eh sil d w ay sil d aa r sil b ah f sil jh
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV:
sil hh ah z sil b ih sil g r iy w ah z sil d aw n ih sil b ay s sil n ey sil w eh l f eh n s ih z eh n dh eh r w er sil g r ey z ih ng sil k ae dx l sil
2021-11-08 20:41:39,829 INFO [pretrained.py:269] Decoding Done
To decode with ``whole-lattice-rescoring`` methond, you can use
.. code-block:: bash
./tdnn_ligru_ctc/pretrained.py \
--method whole-lattice-rescoring \
--checkpoint ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/exp/pretraind.pt \
--words-file ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/data/lang_phone/words.txt \
--HLG ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/data/lang_phone/HLG.pt \
--G ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.1 \
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
The decoding output is:
.. code-block::
2021-11-08 20:37:50,693 INFO [pretrained.py:169] device: cuda:0
2021-11-08 20:37:50,693 INFO [pretrained.py:171] Creating model
2021-11-08 20:37:54,693 INFO [pretrained.py:183] Loading HLG from ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
2021-11-08 20:37:54,705 INFO [pretrained.py:191] Loading G from ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt
2021-11-08 20:37:54,714 INFO [pretrained.py:200] Constructing Fbank computer
2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
2021-11-08 20:37:56,348 INFO [pretrained.py:267]
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV:
sil m ih sil t ih r iy l s sil s er r eh m ih sil m aa l ih ng sil k l ey sil r eh sil d w ay sil d aa r sil b ah f sil jh ch
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV:
sil hh ah z sil b ih n sil g r iy w ah z sil b aw n ih sil b ay s sil n ey sil w er l f eh n s ih z eh n dh eh r w er sil g r ey z ih ng sil k ae dx l sil
2021-11-08 20:37:56,348 INFO [pretrained.py:269] Decoding Done

View File

@ -0,0 +1,390 @@
TDNN-LSTM-CTC
=============
This tutorial shows you how to run a TDNN-LSTM-CTC model with the `TIMIT <https://data.deepai.org/timit.zip>`_ dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
Data preparation
----------------
.. code-block:: bash
$ cd egs/timit/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is to run it.
The data preparation contains several stages, you can use the following two
options:
- ``--stage``
- ``--stop-stage``
to control which stage(s) should be run. By default, all stages are executed.
For example,
.. code-block:: bash
$ cd egs/timit/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
To run stage 2 to stage 5, use:
.. code-block:: bash
$ ./prepare.sh --stage 2 --stop-stage 5
Training
--------
Now describing the training of TDNN-LSTM-CTC model, contained in
the `tdnn_lstm_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_lstm_ctc>`_
folder.
.. HINT::
TIMIT is very small dataset. So one GPU is enough.
The command to run the training part is:
.. code-block:: bash
$ cd egs/timit/ASR
$ export CUDA_VISIBLE_DEVICES="0"
$ ./tdnn_lstm_ctc/train.py
By default, it will run ``30`` epochs. Training logs and checkpoints are saved
in ``tdnn_lstm_ctc/exp``.
In ``tdnn_lstm_ctc/exp``, you will find the following files:
- ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./tdnn_lstm_ctc/train.py --start-epoch 11
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd tdnn_lstm_ctc/exp/tensorboard
$ tensorboard dev upload --logdir . --description "TDNN LSTM training for timit with icefall"
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
To see available training options, you can use:
.. code-block:: bash
$ ./tdnn_lstm_ctc/train.py --help
Other training options, e.g., learning rate, results dir, etc., are
pre-configured in the function ``get_params()``
in `tdnn_lstm_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/timit/ASR/tdnn_lstm_ctc/train.py>`_.
Normally, you don't need to change them. You can change them by modifying the code, if
you want.
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
The command for decoding is:
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0"
$ ./tdnn_lstm_ctc/decode.py
You will see the WER in the output log.
Decoded results are saved in ``tdnn_lstm_ctc/exp``.
.. code-block:: bash
$ ./tdnn_lstm_ctc/decode.py --help
shows you the available decoding options.
Some commonly used options are:
- ``--epoch``
You can select which checkpoint to be used for decoding.
For instance, ``./tdnn_lstm_ctc/decode.py --epoch 10`` means to use
``./tdnn_lstm_ctc/exp/epoch-10.pt`` for decoding.
- ``--avg``
It's related to model averaging. It specifies number of checkpoints
to be averaged. The averaged model is used for decoding.
For example, the following command:
.. code-block:: bash
$ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
for decoding.
- ``--export``
If it is ``True``, i.e., ``./tdnn_lstm_ctc/decode.py --export 1``, the code
will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``.
See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it.
.. _tdnn_lstm_ctc use a pre-trained model:
Pre-trained Model
-----------------
We have uploaded the pre-trained model to
`<https://huggingface.co/luomingshuang/icefall_asr_timit_tdnn_lstm_ctc>`_.
The following shows you how to use the pre-trained model.
Install kaldifeat
~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
extract features for a single sound file or multiple sound files
at the same time.
Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/timit/ASR
$ mkdir tmp-lstm
$ cd tmp-lstm
$ git lfs install
$ git clone https://huggingface.co/luomingshuang/icefall_asr_timit_tdnn_lstm_ctc
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
.. CAUTION::
In order to use this pre-trained model, your k2 version has to be v1.7 or later.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/timit/ASR
$ tree tmp-lstm
.. code-block:: bash
tmp-lstm/
`-- icefall_asr_timit_tdnn_lstm_ctc
|-- README.md
|-- data
| |-- lang_phone
| | |-- HLG.pt
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretrained.pt
`-- test_wavs
|-- FDHC0_SI1559.WAV
|-- FELC0_SI756.WAV
|-- FMGD0_SI1564.WAV
`-- trans.txt
6 directories, 10 files
**File descriptions**:
- ``data/lang_phone/HLG.pt``
It is the decoding graph.
- ``data/lang_phone/tokens.txt``
It contains tokens and their IDs.
- ``data/lang_phone/words.txt``
It contains words and their IDs.
- ``data/lm/G_4_gram.pt``
It is a 4-gram LM, useful for LM rescoring.
- ``exp/pretrained.pt``
It contains pre-trained model parameters, obtained by averaging
checkpoints from ``epoch-16.pt`` to ``epoch-25.pt``.
Note: We have removed optimizer ``state_dict`` to reduce file size.
- ``test_waves/*.WAV``
It contains some test sound files from timit ``TEST`` dataset.
- ``test_waves/trans.txt``
It contains the reference transcripts for the sound files in ``test_waves/``.
The information of the test sound files is listed below:
.. code-block:: bash
$ ffprobe -show_format tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
Input #0, nistsphere, from 'tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV':
Metadata:
database_id : TIMIT
database_version: 1.0
utterance_id : dhc0_si1559
sample_min : -4176
sample_max : 5984
Duration: 00:00:03.40, bitrate: 258 kb/s
Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s
$ ffprobe -show_format tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
Input #0, nistsphere, from 'tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV':
Metadata:
database_id : TIMIT
database_version: 1.0
utterance_id : elc0_si756
sample_min : -1546
sample_max : 1989
Duration: 00:00:04.19, bitrate: 257 kb/s
Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s
$ ffprobe -show_format tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
Input #0, nistsphere, from 'tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV':
Metadata:
database_id : TIMIT
database_version: 1.0
utterance_id : mgd0_si1564
sample_min : -7626
sample_max : 10573
Duration: 00:00:04.44, bitrate: 257 kb/s
Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/timit/ASR
$ ./tdnn_lstm_ctc/pretrained.py --help
shows the usage information of ``./tdnn_lstm_ctc/pretrained.py``.
To decode with ``1best`` method, we can use:
.. code-block:: bash
./tdnn_lstm_ctc/pretrained.py
--method 1best
--checkpoint ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
--words-file ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
--HLG ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
The output is:
.. code-block::
2021-11-08 21:02:49,583 INFO [pretrained.py:169] device: cuda:0
2021-11-08 21:02:49,584 INFO [pretrained.py:171] Creating model
2021-11-08 21:02:53,816 INFO [pretrained.py:183] Loading HLG from ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
2021-11-08 21:02:53,827 INFO [pretrained.py:200] Constructing Fbank computer
2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
2021-11-08 21:02:54,387 INFO [pretrained.py:267]
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV:
sil dh ih sil t ih r ih s sil s er r ih m ih sil m aa l ih ng sil k l ey sil r eh sil d w ay sil d aa r sil b ah f sil <UNK> jh
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV:
sil hh ae z sil b ih n iy w ah z sil b ae n ih sil b ay s sil n ey sil k eh l f eh n s ih z eh n dh eh r w er sil g r ey z ih ng sil k ae dx l sil
2021-11-08 21:02:54,387 INFO [pretrained.py:269] Decoding Done
To decode with ``whole-lattice-rescoring`` methond, you can use
.. code-block:: bash
./tdnn_lstm_ctc/pretrained.py \
--method whole-lattice-rescoring \
--checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretraind.pt \
--words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
--G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp-lstm/icefall_asr_timit_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp-lstm/icefall_asr_timit_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp-lstm/icefall_asr_timit_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
The decoding output is:
.. code-block::
2021-11-08 20:05:22,739 INFO [pretrained.py:169] device: cuda:0
2021-11-08 20:05:22,739 INFO [pretrained.py:171] Creating model
2021-11-08 20:05:26,959 INFO [pretrained.py:183] Loading HLG from ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
2021-11-08 20:05:26,971 INFO [pretrained.py:191] Loading G from ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt
2021-11-08 20:05:26,977 INFO [pretrained.py:200] Constructing Fbank computer
2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
2021-11-08 20:05:27,878 INFO [pretrained.py:267]
./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV:
sil dh ih sil t ih r iy ih s sil s er r eh m ih sil n ah l ih ng sil k l ey sil r eh sil d w ay sil d aa r sil b ow f sil jh
./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV:
sil hh ah z sil b ih n iy w ah z sil b ae n ih sil b ay s sil n ey sil k ih l f eh n s ih z eh n dh eh r w er sil g r ey z ih n sil k ae dx l sil
2021-11-08 20:05:27,878 INFO [pretrained.py:269] Decoding Done

View File

@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the TIMIT dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_timit():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
dataset_parts = (
"TRAIN",
"DEV",
"TEST",
)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"cuts_{partition}.json.gz").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "TRAIN" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomHdf5Writer,
)
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_timit()

153
egs/timit/ASR/prepare.sh Normal file
View File

@ -0,0 +1,153 @@
#!/usr/bin/env bash
set -eou pipefail
num_phones=39
# Here we use num_phones=39 for modeling
nj=15
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/timit
# You can find data, train_data.csv, test_data.csv, etc, inside it.
# You can download them from https://data.deepai.org/timit.zip
#
# - $dl_dir/lm
# This directory contains the language model(LM) downloaded from
# https://huggingface.co/luomingshuang/timit_lm, and the LM is based
# on 39 phones. About how to get these LM files, you can know it
# from https://github.com/luomingshuang/Train_LM_with_kaldilm.
#
# - lm_3_gram_tgmed.arpa
# - lm_4_gram_tgmed.arpa
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
splits_dir=$PWD/splits_dir
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download LM"
# We assume that you have installed the git-lfs, if not, you could install it
# using: `sudo apt-get install git-lfs && git-lfs install`
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
git clone https://huggingface.co/luomingshuang/timit_lm $dl_dir/lm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/timit,
# you can create a symlink
#
# ln -sfv /path/to/timit $dl_dir/timit
#
if [ ! -d $dl_dir/timit ]; then
lhotse download timit $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare timit manifest"
# We assume that you have downloaded the timit corpus
# to $dl_dir/timit
mkdir -p data/manifests
lhotse prepare timit -p $num_phones -j $nj $dl_dir/timit/data data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for timit"
mkdir -p data/fbank
./local/compute_fbank_timit.py
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
./local/compute_fbank_musan.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
./local/prepare_lexicon.py \
--manifests-dir data/manifests \
--lang-dir $lang_dir
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare G"
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$dl_dir/lm/lm_3_gram_tgmed.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
$dl_dir/lm/lm_4_gram_tgmed.arpa > data/lm/G_4_gram.fst.txt
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
fi

View File

View File

@ -0,0 +1,335 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class TimitAsrDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
logging.info("About to create train dataset")
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = [
SpecAugment(
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
)
test_loaders.append(test_dl)
if is_list:
return test_loaders
else:
return test_loaders[0]
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_TRAIN.json.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_DEV.json.gz"
)
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.debug("About to get test cuts")
cuts_test = load_manifest(
self.args.feature_dir / "cuts_TEST.json.gz"
)
return cuts_test

View File

@ -0,0 +1,506 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import TimitAsrDataModule
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=19,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="whole-lattice-rescoring",
help="""Decoding method.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_ligru_ctc/exp/"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"subsampling_factor": 2,
"search_beam": 20,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
nnet_output = model(feature)
# nnet_output is (N, T, C)
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
lm_scale_list = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09]
lm_scale_list += [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
else:
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
lexicon:
It contains word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
G=G,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out PERs, per-phone error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"per-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tPER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
TimitAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
#load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
model.eval()
timit = TimitAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
#test_sets = ["test-clean", "test-other"]
#test_sets = ["test-other"]
#for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
#if test_set == "test-clean": continue
#if test_set == "test-other": break
test_set = "TEST"
test_dl = timit.test_dataloaders()
#test_set = "TRAIN"
#test_dl = timit.train_dataloaders()
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
G=G,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,567 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional
class TdnnLstm(nn.Module):
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 3
) -> None:
"""
Args:
num_features:
The input dimension of the model.
num_classes:
The output dimension of the model.
subsampling_factor:
It reduces the number of output frames by this factor.
"""
super().__init__()
self.num_features = num_features
self.num_classes = num_classes
self.subsampling_factor = subsampling_factor
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=512,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=512, affine=False),
nn.Conv1d(
in_channels=512,
out_channels=512,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=512, affine=False),
nn.Conv1d(
in_channels=512,
out_channels=512,
kernel_size=3,
#stride=self.subsampling_factor, # stride: subsampling_factor!
stride=1,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=512, affine=False),
nn.Conv1d(
in_channels=512,
out_channels=512,
kernel_size=3,
stride=self.subsampling_factor,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=512, affine=False),
)
self.ligrus = nn.ModuleList(
[
LiGRU(input_shape=[None, None, 512], hidden_size=512, num_layers=1, bidirectional=True, re_init=False)
for _ in range(4)
]
)
self.linears = nn.ModuleList(
[
nn.Linear(in_features=1024, out_features=512)
for _ in range(4)
]
)
self.bnorms = nn.ModuleList(
[
nn.BatchNorm1d(num_features=512, affine=False)
for _ in range(4)
]
)
self.dropout = nn.Dropout(0.2)
self.linear = nn.Linear(in_features=512, out_features=self.num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
Its shape is [N, C, T]
Returns:
The output tensor has shape [N, T, C]
"""
x = self.tdnn(x)
x = x.permute(0, 2, 1)
for ligru, linear, bnorm in zip(self.ligrus, self.linears, self.bnorms):
x_new, _ = ligru(x)
#print('ligru output shape: ', x_new.shape)
x_new = linear(x_new)
#print('linear output shape: ', x_new.shape)
x_new = bnorm(x_new.permute(0, 2, 1)).permute(0, 2, 1)
# 2, 0, 1
#) # (T, N, C) -> (N, C, T) -> (T, N, C)
x_new = self.dropout(x_new)
x = x_new + x # skip connections
#x = x.transpose(
# 1, 0
#) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim
x = self.linear(x)
x = nn.functional.log_softmax(x, dim=-1)
return x
class LiGRU(torch.nn.Module):
""" This function implements a Light GRU (liGRU).
LiGRU is single-gate GRU model based on batch-norm + relu
activations + recurrent dropout. For more info see:
"M. Ravanelli, P. Brakel, M. Omologo, Y. Bengio,
Light Gated Recurrent Units for Speech Recognition,
in IEEE Transactions on Emerging Topics in Computational Intelligence,
2018" (https://arxiv.org/abs/1803.10225)
This is a custm RNN and to speed it up it must be compiled with
the torch just-in-time compiler (jit) right before using it.
You can compile it with:
compiled_model = torch.jit.script(model)
It accepts in input tensors formatted as (batch, time, fea).
In the case of 4d inputs like (batch, time, fea, channel) the tensor is
flattened as (batch, time, fea*channel).
Arguments
---------
hidden_size : int
Number of output neurons (i.e, the dimensionality of the output).
values (i.e, time and frequency kernel sizes respectively).
input_shape : tuple
The shape of an example input.
nonlinearity : str
Type of nonlinearity (tanh, relu).
normalization : str
Type of normalization for the ligru model (batchnorm, layernorm).
Every string different from batchnorm and layernorm will result
in no normalization.
num_layers : int
Number of layers to employ in the RNN architecture.
bias : bool
If True, the additive bias b is adopted.
dropout : float
It is the dropout factor (must be between 0 and 1).
re_init : bool
If True, orthogonal initialization is used for the recurrent weights.
Xavier initialization is used for the input connection weights.
bidirectional : bool
If True, a bidirectional model that scans the sequence both
right-to-left and left-to-right is used.
Example
-------
>>> inp_tensor = torch.rand([4, 10, 20])
>>> net = LiGRU(input_shape=inp_tensor.shape, hidden_size=5)
>>> out_tensor, _ = net(inp_tensor)
>>>
torch.Size([4, 10, 5])
"""
def __init__(
self,
hidden_size,
input_shape,
nonlinearity="relu",
normalization="batchnorm",
num_layers=1,
bias=True,
dropout=0.0,
re_init=True,
bidirectional=False,
):
super().__init__()
self.hidden_size = hidden_size
self.nonlinearity = nonlinearity
self.num_layers = num_layers
self.normalization = normalization
self.bias = bias
self.dropout = dropout
self.re_init = re_init
self.bidirectional = bidirectional
self.reshape = False
# Computing the feature dimensionality
if len(input_shape) > 3:
self.reshape = True
self.fea_dim = float(torch.prod(torch.tensor(input_shape[2:])))
self.batch_size = input_shape[0]
self.rnn = self._init_layers()
if self.re_init:
rnn_init(self.rnn)
def _init_layers(self):
"""Initializes the layers of the liGRU."""
rnn = torch.nn.ModuleList([])
#print('fea_dim: ', self.fea_dim)
current_dim = self.fea_dim
for i in range(self.num_layers):
rnn_lay = LiGRU_Layer(
current_dim,
self.hidden_size,
self.num_layers,
self.batch_size,
dropout=self.dropout,
nonlinearity=self.nonlinearity,
normalization=self.normalization,
bidirectional=self.bidirectional,
)
rnn.append(rnn_lay)
if self.bidirectional:
current_dim = self.hidden_size * 2
else:
current_dim = self.hidden_size
return rnn
def forward(self, x, hx: Optional[Tensor] = None):
"""Returns the output of the liGRU.
Arguments
---------
x : torch.Tensor
The input tensor.
hx : torch.Tensor
Starting hidden state.
"""
# Reshaping input tensors for 4d inputs
if self.reshape:
if x.ndim == 4:
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
# run ligru
output, hh = self._forward_ligru(x, hx=hx)
return output, hh
def _forward_ligru(self, x, hx: Optional[Tensor]):
"""Returns the output of the vanilla liGRU.
Arguments
---------
x : torch.Tensor
Input tensor.
hx : torch.Tensor
"""
h = []
if hx is not None:
if self.bidirectional:
hx = hx.reshape(
self.num_layers, self.batch_size * 2, self.hidden_size
)
# Processing the different layers
for i, ligru_lay in enumerate(self.rnn):
if hx is not None:
x = ligru_lay(x, hx=hx[i])
else:
x = ligru_lay(x, hx=None)
h.append(x[:, -1, :])
h = torch.stack(h, dim=1)
if self.bidirectional:
h = h.reshape(h.shape[1] * 2, h.shape[0], self.hidden_size)
else:
h = h.transpose(0, 1)
return x, h
class LiGRU_Layer(torch.nn.Module):
""" This function implements Light-Gated Recurrent Units (ligru) layer.
Arguments
---------
input_size : int
Feature dimensionality of the input tensors.
batch_size : int
Batch size of the input tensors.
hidden_size : int
Number of output neurons.
num_layers : int
Number of layers to employ in the RNN architecture.
nonlinearity : str
Type of nonlinearity (tanh, relu).
normalization : str
Type of normalization (batchnorm, layernorm).
Every string different from batchnorm and layernorm will result
in no normalization.
dropout : float
It is the dropout factor (must be between 0 and 1).
bidirectional : bool
if True, a bidirectional model that scans the sequence both
right-to-left and left-to-right is used.
"""
def __init__(
self,
input_size,
hidden_size,
num_layers,
batch_size,
dropout=0.0,
nonlinearity="relu",
normalization="batchnorm",
bidirectional=False,
):
super(LiGRU_Layer, self).__init__()
self.hidden_size = int(hidden_size)
self.input_size = int(input_size)
self.batch_size = batch_size
self.bidirectional = bidirectional
self.dropout = dropout
self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
self.N_drop_masks = 16000
self.drop_mask_cnt = 0
self.drop_mask_te = torch.tensor([1.0]).float()
self.w = nn.Linear(self.input_size, 2 * self.hidden_size, bias=False)
self.u = nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=False)
print(self.batch_size)
#if self.bidirectional:
# self.batch_size = self.batch_size * 2
# Initializing batch norm
self.normalize = False
if normalization == "batchnorm":
self.norm = nn.BatchNorm1d(2 * self.hidden_size, momentum=0.05)
self.normalize = True
elif normalization == "layernorm":
self.norm = torch.nn.LayerNorm(2 * self.hidden_size)
self.normalize = True
else:
# Normalization is disabled here. self.norm is only formally
# initialized to avoid jit issues.
self.norm = torch.nn.LayerNorm(2 * self.hidden_size)
self.normalize = True
# Initial state
self.register_buffer("h_init", torch.zeros(1, self.hidden_size))
# Preloading dropout masks (gives some speed improvement)
#self._init_drop(self.batch_size)
# Setting the activation function
if nonlinearity == "tanh":
self.act = torch.nn.Tanh()
elif nonlinearity == "sin":
self.act = torch.sin
elif nonlinearity == "leaky_relu":
self.act = torch.nn.LeakyReLU()
else:
self.act = torch.nn.ReLU()
def forward(self, x, hx: Optional[Tensor] = None):
# type: (Tensor, Optional[Tensor]) -> Tensor # noqa F821
"""Returns the output of the liGRU layer.
Arguments
---------
x : torch.Tensor
Input tensor.
"""
if self.bidirectional:
x_flip = x.flip(1)
x = torch.cat([x, x_flip], dim=0)
# Change batch size if needed
self._change_batch_size(x)
# Feed-forward affine transformations (all steps in parallel)
#print(x.shape)
w = self.w(x)
# Apply batch normalization
if self.normalize:
w_bn = self.norm(w.reshape(w.shape[0] * w.shape[1], w.shape[2]))
w = w_bn.reshape(w.shape[0], w.shape[1], w.shape[2])
# Processing time steps
if hx is not None:
h = self._ligru_cell(w, hx)
else:
h = self._ligru_cell(w, self.h_init)
if self.bidirectional:
h_f, h_b = h.chunk(2, dim=0)
h_b = h_b.flip(1)
h = torch.cat([h_f, h_b], dim=2)
return h
def _ligru_cell(self, w, ht):
"""Returns the hidden states for each time step.
Arguments
---------
wx : torch.Tensor
Linearly transformed input.
"""
hiddens = []
# Sampling dropout mask
drop_mask = self._sample_drop_mask(w)
# Loop over time axis
for k in range(w.shape[1]):
gates = w[:, k] + self.u(ht)
at, zt = gates.chunk(2, 1)
zt = torch.sigmoid(zt)
hcand = self.act(at) * drop_mask
ht = zt * ht + (1 - zt) * hcand
hiddens.append(ht)
# Stacking hidden states
h = torch.stack(hiddens, dim=1)
return h
def _init_drop(self, batch_size):
"""Initializes the recurrent dropout operation. To speed it up,
the dropout masks are sampled in advance.
"""
#self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
self.N_drop_masks = 16000
self.drop_mask_cnt = 0
self.register_buffer(
"drop_masks",
self.drop(torch.ones(self.N_drop_masks, self.hidden_size)).data,
)
self.register_buffer("drop_mask_te", torch.tensor([1.0]).float())
def _sample_drop_mask(self, w):
"""Selects one of the pre-defined dropout masks"""
if self.training:
# Sample new masks when needed
if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
self.drop_mask_cnt = 0
self.drop_masks = self.drop(
torch.ones(
self.N_drop_masks, self.hidden_size, device=w.device
)
).data
# Sampling the mask
drop_mask = self.drop_masks[
self.drop_mask_cnt : self.drop_mask_cnt + self.batch_size
]
self.drop_mask_cnt = self.drop_mask_cnt + self.batch_size
else:
self.drop_mask_te = self.drop_mask_te.to(w.device)
drop_mask = self.drop_mask_te
return drop_mask
def _change_batch_size(self, x):
"""This function changes the batch size when it is different from
the one detected in the initialization method. This might happen in
the case of multi-gpu or when we have different batch sizes in train
and test. We also update the h_int and drop masks.
"""
if self.batch_size != x.shape[0]:
self.batch_size = x.shape[0]
if self.training:
self.drop_masks = self.drop(
torch.ones(
self.N_drop_masks, self.hidden_size, device=x.device,
)
).data
class Linear(torch.nn.Module):
"""Computes a linear transformation y = wx + b.
Arguments
---------
n_neurons : int
It is the number of output neurons (i.e, the dimensionality of the
output).
input_shape: tuple
It is the shape of the input tensor.
input_size: int
Size of the input tensor.
bias : bool
If True, the additive bias b is adopted.
combine_dims : bool
If True and the input is 4D, combine 3rd and 4th dimensions of input.
Example
-------
>>> inputs = torch.rand(10, 50, 40)
>>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
>>> output = lin_t(inputs)
>>> output.shape
torch.Size([10, 50, 100])
"""
def __init__(
self,
n_neurons,
input_shape=None,
input_size=None,
bias=True,
combine_dims=False,
):
super().__init__()
self.combine_dims = combine_dims
if input_shape is None and input_size is None:
raise ValueError("Expected one of input_shape or input_size")
if input_size is None:
input_size = input_shape[-1]
if len(input_shape) == 4 and self.combine_dims:
input_size = input_shape[2] * input_shape[3]
# Weights are initialized following pytorch approach
self.w = nn.Linear(input_size, n_neurons, bias=bias)
def forward(self, x):
"""Returns the linear transformation of input tensor.
Arguments
---------
x : torch.Tensor
Input to transform linearly.
"""
if x.ndim == 4 and self.combine_dims:
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
#print(x.shape)
#print(self.w)
wx = self.w(x)
return wx

View File

@ -0,0 +1,278 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.8,
help="""
Used only when method is whole-lattice-rescoring.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"subsampling_factor": 2,
"num_classes": 41,
"sample_rate": 16000,
"search_beam": 20,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method == "whole-lattice-rescoring":
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = features.permute(0, 2, 1) # now features is (N, C, T)
with torch.no_grad():
nnet_output = model(features)
# nnet_output is (N, T, C)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,595 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import TimitAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=60,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
"""
params = AttributeDict(
{
"exp_dir": Path("tdnn_ligru_ctc/exp"),
"lang_dir": Path("data/lang_phone"),
"lr": 1e-3,
"feature_dim": 80,
"weight_decay": 5e-4,
"subsampling_factor": 2,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 200,
"valid_interval": 1000,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of TdnnLstm in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervisions = batch["supervisions"]
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
decoding_graph = graph_compiler.compile(texts)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: CtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = optim.AdamW(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
scheduler.load_state_dict(checkpoints["scheduler"])
timit = TimitAsrDataModule(args)
train_dl = timit.train_dataloaders()
valid_dl = timit.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:
logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}")
if tb_writer is not None:
tb_writer.add_scalar(
"train/lr",
scheduler.get_last_lr()[0],
params.batch_idx_train,
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
scheduler.step()
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
TimitAsrDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,490 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import TimitAsrDataModule
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=25,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="whole-lattice-rescoring",
help="""Decoding method.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp/"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"subsampling_factor": 3,
"search_beam": 20,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
feature = feature.permute(0, 2, 1) # now feature is (N, C, T)
nnet_output = model(feature)
# nnet_output is (N, T, C)
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
else:
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
lexicon:
It contains word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
G=G,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out PERs, per-phone error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"per-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tPER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
TimitAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 20 seconds.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
model.eval()
timit = TimitAsrDataModule(args)
test_dl = timit.test_dataloaders()
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
G=G,
)
test_set = "TEST"
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,278 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.8,
help="""
Used only when method is whole-lattice-rescoring.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"subsampling_factor": 3,
"num_classes": 41,
"sample_rate": 16000,
"search_beam": 20,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method == "whole-lattice-rescoring":
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
features = features.permute(0, 2, 1) # now features is (N, C, T)
with torch.no_grad():
nnet_output = model(features)
# nnet_output is (N, T, C)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../icefall/shared/