Update docs and remove unnecessary arguments (#42)

* Fix typo in docs

* Update docs and remove unnecessary arguments

* Fix code style
This commit is contained in:
Wei Kang 2021-09-13 18:28:57 +08:00 committed by GitHub
parent f792b466bf
commit 24656e9749
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 184 additions and 365 deletions

View File

@ -45,7 +45,7 @@ For example,
.. code-block:: bash
$ cd egs/yesno/ASR
$ cd egs/librispeech/ASR
$ ./prepare.sh --stage 0 --stop-stage 0
means to run only stage 0.
@ -171,7 +171,7 @@ The following options are used quite often:
Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., learning rate,
There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
@ -346,6 +346,10 @@ The following commands describe how to download the pre-trained model:
You have to use ``git lfs`` to download the pre-trained model.
.. CAUTION::
In order to use this pre-trained model, your k2 version has to be v1.7 or later.
After downloading, you will have the following files:
.. code-block:: bash
@ -409,9 +413,9 @@ After downloading, you will have the following files:
It contains some test sound files from LibriSpeech ``test-clean`` dataset.
- `test_waves/trans.txt`
- ``test_waves/trans.txt``
It contains the reference transcripts for the sound files in `test_waves/`.
It contains the reference transcripts for the sound files in ``test_waves/``.
The information of the test sound files is listed below:

View File

@ -153,10 +153,6 @@ Some commonly used options are:
will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``.
See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it.
.. HINT::
There are several decoding methods provided in `tdnn_lstm_ctc/decode.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/train.py>`_, you can change the decoding method by modifying ``method`` parameter in function ``get_params()``.
.. _tdnn_lstm_ctc use a pre-trained model:
@ -168,6 +164,16 @@ We have uploaded the pre-trained model to
The following shows you how to use the pre-trained model.
Install kaldifeat
~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
extract features for a single sound file or multiple sound files
at the same time.
Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -183,6 +189,10 @@ Download the pre-trained model
You have to use ``git lfs`` to download the pre-trained model.
.. CAUTION::
In order to use this pre-trained model, your k2 version has to be v1.7 or later.
After downloading, you will have the following files:
.. code-block:: bash
@ -212,13 +222,75 @@ After downloading, you will have the following files:
6 directories, 10 files
**File descriptions**:
Download kaldifeat
~~~~~~~~~~~~~~~~~~
- ``data/lang_phone/HLG.pt``
It is the decoding graph.
- ``data/lang_phone/tokens.txt``
It contains tokens and their IDs.
- ``data/lang_phone/words.txt``
It contains words and their IDs.
- ``data/lm/G_4_gram.pt``
It is a 4-gram LM, useful for LM rescoring.
- ``exp/pretrained.pt``
It contains pre-trained model parameters, obtained by averaging
checkpoints from ``epoch-14.pt`` to ``epoch-19.pt``.
Note: We have removed optimizer ``state_dict`` to reduce file size.
- ``test_waves/*.flac``
It contains some test sound files from LibriSpeech ``test-clean`` dataset.
- ``test_waves/trans.txt``
It contains the reference transcripts for the sound files in ``test_waves/``.
The information of the test sound files is listed below:
.. code-block:: bash
$ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used for extracting
features from a single or multiple sound files. Please refer to
`<https://github.com/csukuangfj/kaldifeat>`_ to install ``kaldifeat`` first.
Inference with a pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -56,8 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
@ -72,7 +70,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)
@ -85,12 +82,10 @@ class Conformer(Transformer):
dropout,
cnn_module_kernel,
normalize_before,
is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
if self.normalize_before:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
@ -125,7 +120,7 @@ class Conformer(Transformer):
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before and self.is_espnet_structure:
if self.normalize_before:
x = self.after_norm(x)
return x, mask
@ -159,11 +154,10 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
is_espnet_structure: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
@ -436,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_espnet_structure: bool = False,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
@ -459,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module):
self._reset_parameters()
self.is_espnet_structure = is_espnet_structure
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
@ -690,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if not self.is_espnet_structure:
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
@ -785,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module):
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
if not self.is_espnet_structure:
attn_output_weights = (
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
else:
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1

View File

@ -137,15 +137,15 @@ def get_params() -> AttributeDict:
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
# parameters for conformer
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
@ -538,8 +538,6 @@ def main():
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)

View File

@ -173,17 +173,17 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"sample_rate": 16000,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
@ -241,8 +241,6 @@ def main():
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -111,15 +112,6 @@ def get_params() -> AttributeDict:
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -138,23 +130,40 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- head: Number of heads of multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
- weight_decay: The weight_decay for the optimizer.
- lr_factor: The lr_factor for Noam optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -163,17 +172,20 @@ def get_params() -> AttributeDict:
"log_interval": 10,
"reset_interval": 200,
"valid_interval": 3000,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
"accum_grad": 1,
"att_rate": 0.7,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
"att_rate": 0.7,
# parameters for Noam
"weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
}
@ -646,8 +658,6 @@ def run(rank, world_size, args):
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)

View File

@ -41,7 +41,6 @@ class Transformer(nn.Module):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
"""
@ -70,7 +69,6 @@ class Transformer(nn.Module):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
mmi_loss:
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
@ -122,14 +120,9 @@ class Transformer(nn.Module):
)
if num_decoder_layers > 0:
if mmi_loss:
self.decoder_num_class = (
self.num_classes + 1
) # +1 for the sos/eos symbol
else:
self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model

View File

@ -1,270 +0,0 @@
# How to use a pre-trained model to transcribe a sound file or multiple sound files
(See the bottom of this document for the link to a colab notebook.)
You need to prepare 4 files:
- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
```bash
./tdnn_lstm_ctc/pretrained.py --help
```
displays the help information.
## HLG decoding
Once you have the above files ready and have `kaldifeat` installed,
you can run:
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound.wav
```
and you will see the transcribed result.
If you want to transcribe multiple files at the same time, you can use:
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
**Note**: This is the fastest decoding method.
## HLG decoding + LM rescoring
`./tdnn_lstm_ctc/pretrained.py` also supports `whole lattice LM rescoring`.
To use whole lattice LM rescoring, you also need the following files:
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
The command to run decoding with LM rescoring is:
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method whole-lattice-rescoring \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
# Decoding with a pre-trained model in action
We have uploaded a pre-trained model to <https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc>
The following shows the steps about the usage of the provided pre-trained model.
### (1) Download the pre-trained model
```bash
sudo apt-get install git-lfs
cd /path/to/icefall/egs/librispeech/ASR
git lfs install
mkdir tmp
cd tmp
git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
```
**CAUTION**: You have to install `git-lfs` to download the pre-trained model.
You will find the following files:
```
tmp/
`-- icefall_asr_librispeech_tdnn-lstm_ctc
|-- README.md
|-- data
| |-- lang_phone
| | |-- HLG.pt
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretrained.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 10 files
```
**File descriptions**:
- `data/lang_phone/HLG.pt`
It is the decoding graph.
- `data/lang_phone/tokens.txt`
It contains tokens and their IDs.
- `data/lang_phone/words.txt`
It contains words and their IDs.
- `data/lm/G_4_gram.pt`
It is a 4-gram LM, useful for LM rescoring.
- `exp/pretrained.pt`
It contains pre-trained model parameters, obtained by averaging
checkpoints from `epoch-14.pt` to `epoch-19.pt`.
Note: We have removed optimizer `state_dict` to reduce file size.
- `test_waves/*.flac`
It contains some test sound files from LibriSpeech `test-clean` dataset.
- `test_waves/trans.txt`
It contains the reference transcripts for the sound files in `test_waves/`.
The information of the test sound files is listed below:
```
$ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
```
### (2) Use HLG decoding
```bash
cd /path/to/icefall/egs/librispeech/ASR
./tdnn_lstm_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
--words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
```
The output is given below:
```
2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
2021-08-24 16:57:28,098 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
```
### (3) Use HLG decoding + LM rescoring
```bash
./tdnn_lstm_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \
--words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
--HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0
2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model
2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt
2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer
2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started
2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring
2021-08-24 16:39:54,010 INFO [pretrained.py:266]
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
```
**NOTE**: We provide a colab notebook for demonstration.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
Due to limited memory provided by Colab, you have to upgrade to Colab Pro to run `HLG decoding + LM rescoring`.
Otherwise, you can only run `HLG decoding` with Colab.

View File

@ -67,6 +67,47 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="whole-lattice-rescoring",
help="""Decoding method.
Supported values are:
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--export",
type=str2bool,
@ -93,16 +134,6 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "whole-lattice-rescoring",
# "method": "1best",
# "method": "nbest",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 100,
}
)
return params