Merge branch 'k2-fsa:master' into master

This commit is contained in:
Mingshuang Luo 2021-11-08 21:11:00 +08:00 committed by GitHub
commit 13d972c628
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 52 deletions

View File

@ -303,6 +303,10 @@ The commonly used options are:
$ cd egs/librispeech/ASR $ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300 $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300
# Caution: The above command is tested with a model with vocab size 500.
# The default settings in the master will not work.
# Please see https://github.com/k2-fsa/icefall/issues/103
# We will fix it later and delete this note.
And the following command uses attention decoder for rescoring: And the following command uses attention decoder for rescoring:
@ -328,6 +332,8 @@ Usage:
.. code-block:: bash .. code-block:: bash
$ cd egs/librispeech/ASR $ cd egs/librispeech/ASR
# NOTE: Tested with a model with vocab size 500.
# It won't work for a model with vocab size 5000.
$ ./conformer_ctc/decode.py \ $ ./conformer_ctc/decode.py \
--epoch 25 \ --epoch 25 \
--avg 1 \ --avg 1 \
@ -399,7 +405,7 @@ Download the pre-trained model
The following commands describe how to download the pre-trained model: The following commands describe how to download the pre-trained model:
.. code-block:: .. code-block:: bash
$ cd egs/librispeech/ASR $ cd egs/librispeech/ASR
$ mkdir tmp $ mkdir tmp
@ -410,10 +416,23 @@ The following commands describe how to download the pre-trained model:
.. CAUTION:: .. CAUTION::
You have to use ``git lfs`` to download the pre-trained model. You have to use ``git lfs`` to download the pre-trained model.
Otherwise, you will have the following issue when running ``decode.py``:
.. code-block::
_pickle.UnpicklingError: invalid load key, 'v'
To fix that issue, please use:
.. code-block:: bash
cd icefall_asr_librispeech_conformer_ctc
git lfs pull
.. CAUTION:: .. CAUTION::
In order to use this pre-trained model, your k2 version has to be v1.7 or later. In order to use this pre-trained model, your k2 version has to be v1.9 or later.
After downloading, you will have the following files: After downloading, you will have the following files:

View File

@ -362,22 +362,25 @@ def compute_loss(
if params.att_rate != 0.0: if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
if hasattr(model, "module"): mmodel = model.module if hasattr(model, "module") else model
att_loss = model.module.decoder_forward( # Note: We need to generate an unsorted version of token_ids
encoder_memory, # `encode_supervisions()` called above sorts text, but
memory_mask, # encoder_memory and memory_mask are not sorted, so we
token_ids=token_ids, # use an unsorted version `supervisions["text"]` to regenerate
sos_id=graph_compiler.sos_id, # the token_ids
eos_id=graph_compiler.eos_id, #
) # See https://github.com/k2-fsa/icefall/issues/97
else: # for more details
att_loss = model.decoder_forward( unsorted_token_ids = graph_compiler.texts_to_ids(
encoder_memory, supervisions["text"]
memory_mask, )
token_ids=token_ids, att_loss = mmodel.decoder_forward(
sos_id=graph_compiler.sos_id, encoder_memory,
eos_id=graph_compiler.eos_id, memory_mask,
) token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else: else:
loss = ctc_loss loss = ctc_loss

View File

@ -394,24 +394,16 @@ def compute_loss(
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
if params.att_rate != 0.0: if params.att_rate != 0.0:
token_ids = graph_compiler.texts_to_ids(texts) token_ids = graph_compiler.texts_to_ids(supervisions["text"])
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
if hasattr(model, "module"): mmodel = model.module if hasattr(model, "module") else model
att_loss = model.module.decoder_forward( att_loss = mmodel.decoder_forward(
encoder_memory, encoder_memory,
memory_mask, memory_mask,
token_ids=token_ids, token_ids=token_ids,
sos_id=graph_compiler.sos_id, sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id, eos_id=graph_compiler.eos_id,
) )
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
else: else:
loss = mmi_loss loss = mmi_loss

View File

@ -394,24 +394,16 @@ def compute_loss(
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
if params.att_rate != 0.0: if params.att_rate != 0.0:
token_ids = graph_compiler.texts_to_ids(texts) token_ids = graph_compiler.texts_to_ids(supervisions["text"])
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
if hasattr(model, "module"): mmodel = model.module if hasattr(model, "module") else model
att_loss = model.module.decoder_forward( att_loss = mmodel.decoder_forward(
encoder_memory, encoder_memory,
memory_mask, memory_mask,
token_ids=token_ids, token_ids=token_ids,
sos_id=graph_compiler.sos_id, sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id, eos_id=graph_compiler.eos_id,
) )
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
else: else:
loss = mmi_loss loss = mmi_loss

View File

@ -224,6 +224,7 @@ class Nbest(object):
else: else:
word_seq = lattice.aux_labels.index(path) word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2) word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0)
# Each utterance has `num_paths` paths but some of them transduces # Each utterance has `num_paths` paths but some of them transduces
# to the same word sequence, so we need to remove repeated word # to the same word sequence, so we need to remove repeated word
@ -732,6 +733,12 @@ def rescore_with_whole_lattice(
logging.info( logging.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
) )
logging.info(
"This OOM is not an error. You can ignore it. "
"If your model does not converge well, or --max-duration "
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
# NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
# to avoid OOM. You may need to fine tune it. # to avoid OOM. You may need to fine tune it.
@ -864,6 +871,7 @@ def rescore_with_attention_decoder(
ngram_lm_scale_list = [0.01, 0.05, 0.08] ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
ngram_lm_scale_list = [ngram_lm_scale] ngram_lm_scale_list = [ngram_lm_scale]
@ -871,6 +879,7 @@ def rescore_with_attention_decoder(
attention_scale_list = [0.01, 0.05, 0.08] attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
attention_scale_list = [attention_scale] attention_scale_list = [attention_scale]