Add more doc for the yesno recipe.

This commit is contained in:
Fangjun Kuang 2021-08-24 13:50:59 +08:00
parent 39554781b2
commit 49a2b4a9de
6 changed files with 514 additions and 29 deletions


@@ -3,16 +3,17 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
.. image:: _static/logo.png
:alt: icefall logo
:width: 100px
:align: center
:target: https://github.com/k2-fsa/icefall
icefall
=======
Documentation for `icefall <https://github.com/k2-fsa/icefall>`, containing
.. image:: _static/logo.png
:alt: icefall logo
:width: 168px
:align: center
:target: https://github.com/k2-fsa/icefall
Documentation for `icefall <https://github.com/k2-fsa/icefall>`_, containing
speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
.. toctree::


@@ -464,6 +464,6 @@ The decoding log is:
2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2021-08-23 19:35:30,573 INFO [decode.py:299] Done!
Congratulations! You have successfully setup the environment and have run the first recipe in icefall.
**Congratulations!** You have successfully set up the environment and have run the first recipe in ``icefall``.
Have fun with icefall!
Have fun with ``icefall``!


@@ -1,14 +1,13 @@
Recipes
=======
This page contains various recipes in icefall.
This page contains various recipes in ``icefall``.
Currently, only speech recognition recipes are provided.
We may add recipes for other tasks in the future.
We may also add recipes for other tasks in the future.
.. we put the yesno recipe as the first recipe since it is the simplest
.. recipe.
.. Other recipes are sorted alphabetically
.. We put the yesno recipe first since it is the simplest one.
.. Other recipes are listed in alphabetical order.
.. toctree::
:maxdepth: 2


@@ -1,7 +1,43 @@
yesno
=====
This page shows you how to run the ``yesno`` recipe.
This page shows you how to run the ``yesno`` recipe. It contains:
- (1) Prepare data for training
- (2) Train a TDNN model
- (a) View text format logs and visualize TensorBoard logs
- (b) Select the device type, i.e., CPU or GPU, for training
- (c) Change training options
- (d) Resume training from a checkpoint
- (3) Decode with a trained model
- (a) Select a checkpoint for decoding
- (b) Model averaging
- (4) Colab notebook
- (a) It shows you step by step how to set up the environment, how to do training,
and how to do decoding
- (b) How to use a pre-trained model
- (5) Inference with a pre-trained model
- (a) Download a pre-trained model, provided by us
- (b) Decode a single sound file with a pre-trained model
- (c) Decode multiple sound files at the same time
It does **NOT** show you:
- (1) How to train with multiple GPUs
The ``yesno`` dataset is so small that a CPU is more than enough
for training as well as for decoding.
- (2) How to use LM rescoring for decoding
The dataset does not have an LM for rescoring.
.. HINT::
@@ -11,8 +47,8 @@ This page shows you how to run the ``yesno`` recipe.
.. HINT::
You **don't** need a **GPU** to run this recipe. It can be run on a **CPU**.
The training time takes less than 30 **seconds** and you will get
the following WER::
The training part takes less than 30 **seconds** on a CPU and you will get
the following WER at the end::
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
@@ -24,7 +60,7 @@ Data preparation
$ cd egs/yesno/ASR
$ ./prepare.sh
The script ``./prepare.sh`` handles the data preparation for you, automagically.
The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
All you need to do is run it.
The data preparation contains several stages; you can use the following two
@@ -74,7 +110,7 @@ In ``tdnn/exp``, you will find the following files:
- ``epoch-0.pt``, ``epoch-1.pt``, ...
These are checkpoint files, containing model parameters and optimizer ``state_dict``.
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
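If you are curious, you can inspect a checkpoint with plain PyTorch. Below is a
minimal sketch; the ``"model"`` key is an assumption based on the ``--export``
code shown later in this commit:

.. code-block:: python

   import torch

   # Checkpoints are plain dicts saved with torch.save().
   ckpt = torch.load("tdnn/exp/epoch-10.pt", map_location="cpu")
   print(ckpt.keys())           # expected to include "model", among others
   print(ckpt["model"].keys())  # names of the TDNN parameter tensors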
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
@@ -123,11 +159,6 @@ In ``tdnn/exp``, you will find the following files:
you saw printed to the console during training.
To see available training options, you can use:
.. code-block:: bash
$ ./tdnn/train.py --help
.. NOTE::
@@ -152,6 +183,18 @@ To see available training options, you can use:
If you don't have GPUs, then you don't need to
run ``export CUDA_VISIBLE_DEVICES=""``.
To see available training options, you can use:
.. code-block:: bash
$ ./tdnn/train.py --help
Other training options, e.g., the learning rate and the results directory, are
pre-configured in the function ``get_params()``
in `tdnn/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/tdnn/train.py>`_.
Normally, you don't need to change them. If you do, change them by modifying the code.
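``get_params()`` returns an ``AttributeDict``, a dict whose entries can also be
read as attributes (see ``tdnn/pretrained.py`` below for a concrete instance).
A small illustration, with values taken from that script:

.. code-block:: python

   from icefall.utils import AttributeDict

   params = AttributeDict({"feature_dim": 23, "num_classes": 4})
   print(params.feature_dim)   # attribute-style access prints 23
   params["feature_dim"] = 23  # dict-style access works as usual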
Decoding
--------
@@ -169,6 +212,225 @@ You will see the WER in the output log.
Decoded results are saved in ``tdnn/exp``.
.. code-block:: bash
$ ./tdnn/decode.py --help
shows you the available decoding options.
Some commonly used options are:
- ``--epoch``
It selects which checkpoint to use for decoding.
For instance, ``./tdnn/decode.py --epoch 10`` means to use
``./tdnn/exp/epoch-10.pt`` for decoding.
- ``--avg``
It's related to model averaging. It specifies the number of consecutive
checkpoints to average; the averaged model is used for decoding (see the
sketch after this list of options).
For example, the following command:
.. code-block:: bash
$ ./tdnn/decode.py --epoch 10 --avg 3
uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
for decoding.
- ``--export``
If it is ``True``, i.e., ``./tdnn/decode.py --export 1``, the code
will save the averaged model to ``tdnn/exp/pretrained.pt``.
See :ref:`yesno use a pre-trained model` for how to use it.
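To make model averaging concrete, here is a minimal sketch of the idea. The
``"model"`` key and the plain element-wise mean are assumptions based on the
``decode.py`` changes shown below; the actual implementation is
``icefall.checkpoint.average_checkpoints()``:

.. code-block:: python

   import torch

   # Conceptually, ``--epoch 10 --avg 3`` averages these three checkpoints.
   filenames = [f"tdnn/exp/epoch-{i}.pt" for i in (8, 9, 10)]
   states = [torch.load(f, map_location="cpu")["model"] for f in filenames]

   # Element-wise mean of each parameter across the checkpoints.
   # (Integer buffers, if any, would need special handling; this sketch
   # ignores that.)
   avg = {k: sum(s[k] for s in states) / len(states) for k in states[0]}

   # ``--export 1`` saves a dict of this shape to tdnn/exp/pretrained.pt.
   torch.save({"model": avg}, "tdnn/exp/pretrained.pt")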
.. _yesno use a pre-trained model:
Pre-trained Model
-----------------
We have uploaded the pre-trained model to
`<https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn>`_.
The following shows you how to use the pre-trained model.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/yesno/ASR
$ mkdir tmp
$ cd tmp
$ git lfs install
$ git clone https://huggingface.co/csukuangfj/icefall_asr_yesno_tdnn
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/yesno/ASR
$ tree tmp
.. code-block:: bash
tmp/
`-- icefall_asr_yesno_tdnn
|-- README.md
|-- lang_phone
| |-- HLG.pt
| |-- L.pt
| |-- L_disambig.pt
| |-- Linv.pt
| |-- lexicon.txt
| |-- lexicon_disambig.txt
| |-- tokens.txt
| `-- words.txt
|-- lm
| |-- G.arpa
| `-- G.fst.txt
|-- pretrained.pt
`-- test_waves
|-- 0_0_0_1_0_0_0_1.wav
|-- 0_0_1_0_0_0_1_0.wav
|-- 0_0_1_0_0_1_1_1.wav
|-- 0_0_1_0_1_0_0_1.wav
|-- 0_0_1_1_0_0_0_1.wav
|-- 0_0_1_1_0_1_1_0.wav
|-- 0_0_1_1_1_0_0_0.wav
|-- 0_0_1_1_1_1_0_0.wav
|-- 0_1_0_0_0_1_0_0.wav
|-- 0_1_0_0_1_0_1_0.wav
|-- 0_1_0_1_0_0_0_0.wav
|-- 0_1_0_1_1_1_0_0.wav
|-- 0_1_1_0_0_1_1_1.wav
|-- 0_1_1_1_0_0_1_0.wav
|-- 0_1_1_1_1_0_1_0.wav
|-- 1_0_0_0_0_0_0_0.wav
|-- 1_0_0_0_0_0_1_1.wav
|-- 1_0_0_1_0_1_1_1.wav
|-- 1_0_1_1_0_1_1_1.wav
|-- 1_0_1_1_1_1_0_1.wav
|-- 1_1_0_0_0_1_1_1.wav
|-- 1_1_0_0_1_0_1_1.wav
|-- 1_1_0_1_0_1_0_0.wav
|-- 1_1_0_1_1_0_0_1.wav
|-- 1_1_0_1_1_1_1_0.wav
|-- 1_1_1_0_0_1_0_1.wav
|-- 1_1_1_0_1_0_1_0.wav
|-- 1_1_1_1_0_0_1_0.wav
|-- 1_1_1_1_1_0_0_0.wav
`-- 1_1_1_1_1_1_1_1.wav
4 directories, 42 files
.. code-block:: bash
$ soxi tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
Input File : 'tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav'
Channels : 1
Sample Rate : 8000
Precision : 16-bit
Duration : 00:00:06.76 = 54080 samples ~ 507 CDDA sectors
File Size : 108k
Bit Rate : 128k
Sample Encoding: 16-bit Signed Integer PCM
- ``0_0_1_0_1_0_0_1.wav``
0 means No; 1 means Yes. The spoken No and Yes are not English words,
but `Hebrew <https://en.wikipedia.org/wiki/Hebrew_language>`_ ones.
So this file contains ``NO NO YES NO YES NO NO YES``.
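In other words, the reference transcript is encoded in the file name itself. A
tiny illustration:

.. code-block:: python

   # Map each 0/1 in the file name to NO/YES.
   name = "0_0_1_0_1_0_0_1"
   words = ["YES" if c == "1" else "NO" for c in name.split("_")]
   print(" ".join(words))  # prints: NO NO YES NO YES NO NO YES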
Download kaldifeat
~~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to extract
features from single or multiple sound files. Please refer to
`<https://github.com/csukuangfj/kaldifeat>`_ to install ``kaldifeat`` first.
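As a quick check after installation, you can compute Fbank features the same way
``./tdnn/pretrained.py`` (shown below) does. The option values mirror that
script, so this is just a sketch of the expected usage:

.. code-block:: python

   import torch
   import kaldifeat

   opts = kaldifeat.FbankOptions()
   opts.frame_opts.samp_freq = 8000  # the yesno recordings are 8 kHz
   opts.frame_opts.dither = 0
   opts.frame_opts.snip_edges = False
   opts.mel_opts.num_bins = 23       # 23-dimensional Fbank features

   fbank = kaldifeat.Fbank(opts)
   wave = torch.rand(2 * 8000)       # 2 seconds of random 8 kHz audio
   features = fbank([wave])          # a list in, a list of 2-D tensors out
   print(features[0].shape)          # (num_frames, 23)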
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ cd egs/yesno/ASR
$ ./tdnn/pretrained.py --help
shows the usage information of ``./tdnn/pretrained.py``.
To decode a single file, we can use:
.. code-block:: bash
./tdnn/pretrained.py \
--checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
--words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav
The output is:
.. code-block::
2021-08-24 12:22:51,621 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']}
2021-08-24 12:22:51,645 INFO [pretrained.py:125] device: cpu
2021-08-24 12:22:51,645 INFO [pretrained.py:127] Creating model
2021-08-24 12:22:51,650 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
2021-08-24 12:22:51,651 INFO [pretrained.py:143] Constructing Fbank computer
2021-08-24 12:22:51,652 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav']
2021-08-24 12:22:51,684 INFO [pretrained.py:159] Decoding started
2021-08-24 12:22:51,708 INFO [pretrained.py:198]
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
NO NO YES NO YES NO NO YES
2021-08-24 12:22:51,708 INFO [pretrained.py:200] Decoding Done
You can see that for the sound file ``0_0_1_0_1_0_0_1.wav``, the decoding result is
``NO NO YES NO YES NO NO YES``.
To decode **multiple** files at the same time, you can use
.. code-block:: bash
./tdnn/pretrained.py \
--checkpoint ./tmp/icefall_asr_yesno_tdnn/pretrained.pt \
--words-file ./tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt \
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav \
./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav
The decoding output is:
.. code-block::
2021-08-24 12:25:20,159 INFO [pretrained.py:119] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tmp/icefall_asr_yesno_tdnn/pretrained.pt', 'words_file': './tmp/icefall_asr_yesno_tdnn/lang_phone/words.txt', 'HLG': './tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt', 'sound_files': ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav', './tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']}
2021-08-24 12:25:20,181 INFO [pretrained.py:125] device: cpu
2021-08-24 12:25:20,181 INFO [pretrained.py:127] Creating model
2021-08-24 12:25:20,185 INFO [pretrained.py:139] Loading HLG from ./tmp/icefall_asr_yesno_tdnn/lang_phone/HLG.pt
2021-08-24 12:25:20,186 INFO [pretrained.py:143] Constructing Fbank computer
2021-08-24 12:25:20,187 INFO [pretrained.py:153] Reading sound files: ['./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav',
'./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav']
2021-08-24 12:25:20,213 INFO [pretrained.py:159] Decoding started
2021-08-24 12:25:20,287 INFO [pretrained.py:198]
./tmp/icefall_asr_yesno_tdnn/test_waves/0_0_1_0_1_0_0_1.wav:
NO NO YES NO YES NO NO YES
./tmp/icefall_asr_yesno_tdnn/test_waves/1_0_1_1_0_1_1_1.wav:
YES NO YES YES NO YES YES YES
2021-08-24 12:25:20,287 INFO [pretrained.py:200] Decoding Done
You can see again that it decodes correctly.
Colab notebook
--------------
@@ -180,8 +442,4 @@ We do provide a colab notebook for this recipe.
:target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing
Use a pre-trained model
-----------------------
TODO
**Congratulations!** You have finished the simplest speech recognition recipe in ``icefall``.


@@ -20,6 +20,7 @@ from icefall.utils import (
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@@ -44,6 +45,17 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
return parser
@@ -279,6 +291,12 @@ def main():
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
model.to(device)
model.eval()

egs/yesno/ASR/tdnn/pretrained.py (new executable file, 209 lines)

@@ -0,0 +1,209 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from model import Tdnn
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"num_classes": 4, # [<blk>, N, SIL, Y]
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Tdnn(
num_features=params.feature_dim,
num_classes=params.num_classes,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: The TDNN model has no attention, so no key padding mask is needed.
with torch.no_grad():
nnet_output = model(features)
batch_size = nnet_output.shape[0]
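# Each row of supervision_segments is (sequence_index, start_frame,
# num_frames). Every utterance here starts at frame 0 and spans the
# whole padded length.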
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()