Add more doc for the yesno recipe.
This commit is contained in: parent 39554781b2, commit 49a2b4a9de
@@ -3,16 +3,17 @@
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.

-.. image:: _static/logo.png
-   :alt: icefall logo
-   :width: 100px
-   :align: center
-   :target: https://github.com/k2-fsa/icefall
-
 icefall
 =======

-Documentation for `icefall <https://github.com/k2-fsa/icefall>`, containing
+.. image:: _static/logo.png
+   :alt: icefall logo
+   :width: 168px
+   :align: center
+   :target: https://github.com/k2-fsa/icefall
+
+
+Documentation for `icefall <https://github.com/k2-fsa/icefall>`_, containing
 speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.

 .. toctree::
@@ -464,6 +464,6 @@ The decoding log is:
   2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
   2021-08-23 19:35:30,573 INFO [decode.py:299] Done!

-Congratulations! You have successfully setup the environment and have run the first recipe in icefall.
+**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.

-Have fun with icefall!
+Have fun with ``icefall``!
@ -1,14 +1,13 @@
|
|||||||
Recipes
|
Recipes
|
||||||
=======
|
=======
|
||||||
|
|
||||||
This page contains various recipes in icefall.
|
This page contains various recipes in ``icefall``.
|
||||||
Currently, only speech recognition recipes are provided.
|
Currently, only speech recognition recipes are provided.
|
||||||
|
|
||||||
We may add recipes for other tasks in the future.
|
We may add recipes for other tasks as well in the future.
|
||||||
|
|
||||||
.. we put the yesno recipe as the first recipe since it is the simplest
|
.. we put the yesno recipe as the first recipe since it is the simplest one.
|
||||||
.. recipe.
|
.. Other recipes are listed in a alphabetical order.
|
||||||
.. Other recipes are sorted alphabetically
|
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
@@ -1,7 +1,43 @@
 yesno
 =====

-This page shows you how to run the ``yesno`` recipe.
+This page shows you how to run the ``yesno`` recipe. It contains:

+- (1) Prepare data for training
+- (2) Train a TDNN model
+
+  - (a) View text format logs and visualize TensorBoard logs
+  - (b) Select device type, i.e., CPU and GPU, for training
+  - (c) Change training options
+  - (d) Resume training from a checkpoint
+
+- (3) Decode with a trained model
+
+  - (a) Select a checkpoint for decoding
+  - (b) Model averaging
+
+- (4) Colab notebook
+
+  - (a) It shows you step by step how to setup the environment, how to do training,
+    and how to do decoding
+  - (b) How to use a pre-trained model
+
+- (5) Inference with a pre-trained model
+
+  - (a) Download a pre-trained model, provided by us
+  - (b) Decode a single sound file with a pre-trained model
+  - (c) Decode multiple sound files at the same time
+
+It does **NOT** show you:
+
+- (1) How to train with multiple GPUs
+
+  The ``yesno`` dataset is so small that CPU is more than enough
+  for training as well as for decoding.
+
+- (2) How to use LM rescoring for decoding
+
+  The dataset does not have an LM for rescoring.
+
 .. HINT::

@@ -11,8 +47,8 @@ This page shows you how to run the ``yesno`` recipe.
 .. HINT::

    You **don't** need a **GPU** to run this recipe. It can be run on a **CPU**.
-   The training time takes less than 30 **seconds** and you will get
-   the following WER::
+   The training part takes less than 30 **seconds** on a CPU and you will get
+   the following WER at the end::

      [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]

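(For reference, the WER above follows the usual convention: (insertions + deletions + substitutions) divided by the number of reference words, i.e. (0 + 1 + 0) / 240 ≈ 0.42%.)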
@@ -24,7 +60,7 @@ Data preparation
   $ cd egs/yesno/ASR
   $ ./prepare.sh

-The script ``./prepare.sh`` handles the data preparation for you, automagically.
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
 All you need to do is to run it.

 The data preparation contains several stages, you can use the following two
@@ -74,7 +110,7 @@ In ``tdnn/exp``, you will find the following files:

 - ``epoch-0.pt``, ``epoch-1.pt``, ...

-  These are checkpoint files, containing model parameters and optimizer ``state_dict``.
+  These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
   To resume training from some checkpoint, say ``epoch-10.pt``, you can use:

   .. code-block:: bash
@ -123,11 +159,6 @@ In ``tdnn/exp``, you will find the following files:
|
|||||||
you saw printed to the console during training.
|
you saw printed to the console during training.
|
||||||
|
|
||||||
|
|
||||||
To see available training options, you can use:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
$ ./tdnn/train.py --help
|
|
||||||
|
|
||||||
.. NOTE::
|
.. NOTE::
|
||||||
|
|
||||||
@@ -152,6 +183,18 @@ To see available training options, you can use:
    If you don't have GPUs, then you don't need to
    run ``export CUDA_VISIBLE_DEVICES=""``.

+To see available training options, you can use:
+
+.. code-block:: bash
+
+  $ ./tdnn/train.py --help
+
+Other training options, e.g., learning rate, results dir, etc., are
+pre-configured in the function ``get_params()``
+in `tdnn/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/tdnn/train.py>`_.
+Normally, you don't need to change them. You can change them by modifying the code, if
+you want.
+
 Decoding
 --------

@@ -169,6 +212,225 @@ You will see the WER in the output log.

 Decoded results are saved in ``tdnn/exp``.

+.. code-block:: bash
+
+  $ ./tdnn/decode.py --help
+
+shows you the available decoding options.
+
+Some commonly used options are:
+
+- ``--epoch``
+
+  You can select which checkpoint to be used for decoding.
+  For instance, ``./tdnn/decode.py --epoch 10`` means to use
+  ``./tdnn/exp/epoch-10.pt`` for decoding.
+
+- ``--avg``
+
+  It's related to model averaging. It specifies number of checkpoints
+  to be averaged. The averaged model is used for decoding.
+  For example, the following command:
+
+  .. code-block:: bash
+
+    $ ./tdnn/decode.py --epoch 10 --avg 3
+
+  uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
+  for decoding.
+
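Note: model averaging as used by ``--avg`` averages the corresponding ``state_dict`` tensors of the selected checkpoints element-wise and loads the result back into the model. The helper actually used is ``icefall.checkpoint.average_checkpoints``; the sketch below only illustrates the idea, assuming each checkpoint stores its parameters under a ``"model"`` key (as the ``--export`` help text in this commit suggests) and that all entries are floating-point tensors.

.. code-block:: python

  import torch

  def average_checkpoints_sketch(filenames):
      """Element-wise average of the "model" state_dicts in `filenames`."""
      avg = torch.load(filenames[0], map_location="cpu")["model"]
      for f in filenames[1:]:
          state = torch.load(f, map_location="cpu")["model"]
          for k in avg:
              avg[k] += state[k]
      for k in avg:
          avg[k] /= len(filenames)
      return avg

  # --epoch 10 --avg 3 corresponds to averaging these three checkpoints:
  # avg = average_checkpoints_sketch(
  #     ["tdnn/exp/epoch-8.pt", "tdnn/exp/epoch-9.pt", "tdnn/exp/epoch-10.pt"]
  # )
  # model.load_state_dict(avg)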
+- ``--export``
+
+  If it is ``True``, i.e., ``./tdnn/decode.py --export 1``, the code
+  will save the averaged model to ``tdnn/exp/pretrained.pt``.
+  See :ref:`yesno use a pre-trained model` for how to use it.
+
+
+.. _yesno use a pre-trained model:
+
+Pre-trained Model
+-----------------
+
+We have uploaded the pre-trained model to
+`<https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn>`_.
+
+The following shows you how to use the pre-trained model.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/yesno/ASR
+  $ mkdir tmp
+  $ cd tmp
+  $ git lfs install
+  $ git clone https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn
+
+.. CAUTION::
+
+   You have to use ``git lfs`` to download the pre-trained model.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+  $ cd egs/yesno/ASR
+  $ tree tmp
+
+.. code-block:: bash
+
+  tmp/
+  `-- icefall_asr_yesno_tdnn
+      |-- README.md
+      |-- lang_phone
+      |   |-- HLG.pt
+      |   |-- L.pt
+      |   |-- L_disambig.pt
+      |   |-- Linv.pt
+      |   |-- lexicon.txt
+      |   |-- lexicon_disambig.txt
+      |   |-- tokens.txt
+      |   `-- words.txt
+      |-- lm
+      |   |-- G.arpa
+      |   `-- G.fst.txt
+      |-- pretrained.pt
+      `-- test_waves
+          |-- 0_0_0_1_0_0_0_1.wav
+          |-- 0_0_1_0_0_0_1_0.wav
+          |-- 0_0_1_0_0_1_1_1.wav
+          |-- 0_0_1_0_1_0_0_1.wav
+          |-- 0_0_1_1_0_0_0_1.wav
+          |-- 0_0_1_1_0_1_1_0.wav
+          |-- 0_0_1_1_1_0_0_0.wav
+          |-- 0_0_1_1_1_1_0_0.wav
+          |-- 0_1_0_0_0_1_0_0.wav
+          |-- 0_1_0_0_1_0_1_0.wav
+          |-- 0_1_0_1_0_0_0_0.wav
+          |-- 0_1_0_1_1_1_0_0.wav
+          |-- 0_1_1_0_0_1_1_1.wav
+          |-- 0_1_1_1_0_0_1_0.wav
+          |-- 0_1_1_1_1_0_1_0.wav
+          |-- 1_0_0_0_0_0_0_0.wav
+          |-- 1_0_0_0_0_0_1_1.wav
+          |-- 1_0_0_1_0_1_1_1.wav
+          |-- 1_0_1_1_0_1_1_1.wav
+          |-- 1_0_1_1_1_1_0_1.wav
+          |-- 1_1_0_0_0_1_1_1.wav
+          |-- 1_1_0_0_1_0_1_1.wav
+          |-- 1_1_0_1_0_1_0_0.wav
+          |-- 1_1_0_1_1_0_0_1.wav
+          |-- 1_1_0_1_1_1_1_0.wav
+          |-- 1_1_1_0_0_1_0_1.wav
+          |-- 1_1_1_0_1_0_1_0.wav
+          |-- 1_1_1_1_0_0_1_0.wav
+          |-- 1_1_1_1_1_0_0_0.wav
+          `-- 1_1_1_1_1_1_1_1.wav
+
+  4 directories, 42 files
+
+.. code-block:: bash
+
+  $ soxi tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
+
+  Input File     : 'tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav'
+  Channels       : 1
+  Sample Rate    : 8000
+  Precision      : 16-bit
+  Duration       : 00:00:06.76 = 54080 samples ~ 507 CDDA sectors
+  File Size      : 108k
+  Bit Rate       : 128k
+  Sample Encoding: 16-bit Signed Integer PCM
+
+- ``0_0_1_0_1_0_0_1.wav``
+
+  0 means No; 1 means Yes. No and Yes are not in English,
+  but in `Hebrew <https://en.wikipedia.org/wiki/Hebrew_language>`_.
+  So this file contains ``NO NO YES NO YES NO NO YES``.
+
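Note: since every test wave is named after its 0/1 sequence, the expected transcript can be derived from the file name alone. A small sketch, assuming the naming convention just described:

.. code-block:: python

  from pathlib import Path

  def expected_transcript(wav_path: str) -> str:
      # "0_0_1_0_1_0_0_1.wav" -> "NO NO YES NO YES NO NO YES"
      digits = Path(wav_path).stem.split("_")
      return " ".join("YES" if d == "1" else "NO" for d in digits)

  print(expected_transcript("tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav"))
  # NO NO YES NO YES NO NO YES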
+Download kaldifeat
+~~~~~~~~~~~~~~~~~~
+
+`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used for extracting
+features from a single or multiple sound files. Please refer to
+`<https://github.com/csukuangfj/kaldifeat>`_ to install ``kaldifeat`` first.
+
+Inference with a pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/yesno/ASR
+  $ ./tdnn/pretrained.py --help
+
+shows the usage information of ``./tdnn/pretrained.py``.
+
+To decode a single file, we can use:
+
+.. code-block:: bash
+
+  ./tdnn/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
+    --words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
+    --HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
+    ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
+
+The output is:
+
+.. code-block::
+
+  2021-08-24 12:22:51,621 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']}
+  2021-08-24 12:22:51,645 INFO [pretrained.py:125] device: cpu
+  2021-08-24 12:22:51,645 INFO [pretrained.py:127] Creating model
+  2021-08-24 12:22:51,650 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
+  2021-08-24 12:22:51,651 INFO [pretrained.py:143] Constructing Fbank computer
+  2021-08-24 12:22:51,652 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']
+  2021-08-24 12:22:51,684 INFO [pretrained.py:159] Decoding started
+  2021-08-24 12:22:51,708 INFO [pretrained.py:198]
+  ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
+  NO NO YES NO YES NO NO YES
+
+
+  2021-08-24 12:22:51,708 INFO [pretrained.py:200] Decoding Done
+
+You can see that for the sound file ``0_0_1_0_1_0_0_1.wav``, the decoding result is
+``NO NO YES NO YES NO NO YES``.
+
+To decode **multiple** files at the same time, you can use
+
+.. code-block:: bash
+
+  ./tdnn/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
+    --words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
+    --HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
+    ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav \
+    ./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav
+
+The decoding output is:
+
+.. code-block::
+
+  2021-08-24 12:25:20,159 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav', './tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']}
+  2021-08-24 12:25:20,181 INFO [pretrained.py:125] device: cpu
+  2021-08-24 12:25:20,181 INFO [pretrained.py:127] Creating model
+  2021-08-24 12:25:20,185 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
+  2021-08-24 12:25:20,186 INFO [pretrained.py:143] Constructing Fbank computer
+  2021-08-24 12:25:20,187 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav',
+  './tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']
+  2021-08-24 12:25:20,213 INFO [pretrained.py:159] Decoding started
+  2021-08-24 12:25:20,287 INFO [pretrained.py:198]
+  ./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
+  NO NO YES NO YES NO NO YES
+
+  ./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav:
+  YES NO YES YES NO YES YES YES
+
+  2021-08-24 12:25:20,287 INFO [pretrained.py:200] Decoding Done
+
+You can see again that it decodes correctly.
+
 Colab notebook
 --------------

@ -180,8 +442,4 @@ We do provide a colab notebook for this recipe.
|
|||||||
:target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing
|
:target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing
|
||||||
|
|
||||||
|
|
||||||
|
**Congratulations!** You have finished the simplest speech recognition recipe in ``icefall``.
|
||||||
Use a pre-trained model
|
|
||||||
-----------------------
|
|
||||||
|
|
||||||
TODO
|
|
||||||
|
@@ -20,6 +20,7 @@ from icefall.utils import (
     get_texts,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )

@@ -44,6 +45,17 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )

+    parser.add_argument(
+        "--export",
+        type=str2bool,
+        default=False,
+        help="""When enabled, the averaged model is saved to
+        tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        pretrained.pt contains a dict {"model": model.state_dict()},
+        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
+        """,
+    )
     return parser


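Note: ``str2bool`` (imported from ``icefall.utils`` above) lets a boolean option accept values such as ``1``/``0`` or ``true``/``false`` on the command line, e.g. ``--export 1``. Its exact implementation is not shown in this diff; a typical argparse-friendly version looks roughly like this:

.. code-block:: python

  import argparse

  def str2bool(v):
      """Parse common textual spellings of a boolean command-line value."""
      if isinstance(v, bool):
          return v
      if v.lower() in ("yes", "true", "t", "y", "1"):
          return True
      if v.lower() in ("no", "false", "f", "n", "0"):
          return False
      raise argparse.ArgumentTypeError("Boolean value expected.")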
@@ -279,6 +291,12 @@ def main():
         logging.info(f"averaging {filenames}")
         model.load_state_dict(average_checkpoints(filenames))

+    if params.export:
+        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
+
     model.to(device)
     model.eval()

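Note: with ``--export 1`` the saved ``tdnn/exp/pretrained.pt`` holds only ``{"model": model.state_dict()}``, so loading it back for inference is straightforward. A minimal sketch, assuming it is run from ``egs/yesno/ASR`` (the 23/4 values match ``get_params()`` in this recipe):

.. code-block:: python

  import sys

  import torch

  sys.path.insert(0, "tdnn")  # so that tdnn/model.py is importable
  from model import Tdnn

  model = Tdnn(num_features=23, num_classes=4)
  checkpoint = torch.load("tdnn/exp/pretrained.pt", map_location="cpu")
  model.load_state_dict(checkpoint["model"])
  model.eval()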
|
209
egs/yesno/ASR/tdnn/pretrained.py
Executable file
209
egs/yesno/ASR/tdnn/pretrained.py
Executable file
@ -0,0 +1,209 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
from model import Tdnn
from torch.nn.utils.rnn import pad_sequence

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        required=True,
        help="Path to words.txt",
    )

    parser.add_argument(
        "--HLG", type=str, required=True, help="Path to HLG.pt."
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 23,
            "num_classes": 4,  # [<blk>, N, SIL, Y]
            "sample_rate": 8000,
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
        }
    )
    return params


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0])
    return ans


def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))
    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")

    model = Tdnn(
        num_features=params.feature_dim,
        num_classes=params.num_classes,
    )

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()

    logging.info(f"Loading HLG from {params.HLG}")
    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
    HLG = HLG.to(device)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)

    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    # Note: We don't use key padding mask for attention during decoding
    with torch.no_grad():
        nnet_output = model(features)

    batch_size = nnet_output.shape[0]
    supervision_segments = torch.tensor(
        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
        dtype=torch.int32,
    )

    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=params.use_double_scores
    )

    hyps = get_texts(best_path)
    word_sym_table = k2.SymbolTable.from_file(params.words_file)
    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()