Merge branch 'k2-fsa:master' into master
commit 13d972c628
@@ -303,6 +303,10 @@ The commonly used options are:

    $ cd egs/librispeech/ASR

    $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300

+   # Caution: The above command is tested with a model with vocab size 500.
+   # The default settings in the master will not work.
+   # Please see https://github.com/k2-fsa/icefall/issues/103
+   # We will fix it later and delete this note.

 And the following command uses attention decoder for rescoring:
@@ -328,6 +332,8 @@ Usage:

 .. code-block:: bash

    $ cd egs/librispeech/ASR

+   # NOTE: Tested with a model with vocab size 500.
+   # It won't work for a model with vocab size 5000.
    $ ./conformer_ctc/decode.py \
        --epoch 25 \
        --avg 1 \
@@ -399,7 +405,7 @@ Download the pre-trained model

 The following commands describe how to download the pre-trained model:

-.. code-block::
+.. code-block:: bash

    $ cd egs/librispeech/ASR
    $ mkdir tmp
@@ -410,10 +416,23 @@ The following commands describe how to download the pre-trained model:

 .. CAUTION::

+   You have to use ``git lfs`` to download the pre-trained model.
+   Otherwise, you will have the following issue when running ``decode.py``:
+
+   .. code-block::
+
+      _pickle.UnpicklingError: invalid load key, 'v'
+
+   To fix that issue, please use:
+
+   .. code-block:: bash
+
+      cd icefall_asr_librispeech_conformer_ctc
+      git lfs pull
+
+
+.. CAUTION::
+
-   In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+   In order to use this pre-trained model, your k2 version has to be v1.9 or later.

 After downloading, you will have the following files:
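The _pickle.UnpicklingError mentioned in the caution above arises because a git-lfs pointer file is a small text file that begins with the line "version https://git-lfs.github.com/spec/v1", so torch.load() fails on the leading 'v'. Below is a minimal sketch of a guard one could place in front of torch.load(); load_checkpoint_safely is a hypothetical helper, not part of icefall:

    import torch

    def load_checkpoint_safely(path: str):
        # A git-lfs pointer file is plain text starting with "version ...",
        # whereas a real checkpoint is a binary pickle/zip archive.
        with open(path, "rb") as f:
            head = f.read(7)
        if head == b"version":
            raise RuntimeError(
                f"{path} looks like a git-lfs pointer file; "
                "run 'git lfs pull' in the model directory first."
            )
        return torch.load(path, map_location="cpu")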
@@ -362,19 +362,22 @@ def compute_loss(

     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
+            att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
-                token_ids=token_ids,
+                token_ids=unsorted_token_ids,
                 sos_id=graph_compiler.sos_id,
                 eos_id=graph_compiler.eos_id,
             )
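The added comment refers to the ordering bug tracked in issue #97: encode_supervisions() sorts the supervision segments, so token_ids built from the sorted texts no longer line up row by row with encoder_memory and memory_mask, which keep the original batch order. A small self-contained sketch of the mismatch and of the fix applied here; the toy texts and the texts_to_ids stand-in are illustrative only:

    # Hypothetical illustration of the bug fixed above, using plain lists.
    texts = ["short", "a much longer utterance", "medium one"]  # batch order

    # encode_supervisions() sorts the segments (e.g. longest first):
    sorted_texts = sorted(texts, key=len, reverse=True)

    def texts_to_ids(ts):
        # stand-in for graph_compiler.texts_to_ids(): one id list per utterance
        return [[ord(c) for c in t] for t in ts]

    token_ids = texts_to_ids(sorted_texts)    # follows the *sorted* order
    unsorted_token_ids = texts_to_ids(texts)  # follows the original order

    # encoder_memory / memory_mask rows stay in the original order, so pairing
    # row i with token_ids[i] mixes up utterances; unsorted_token_ids[i] is
    # the transcript that actually belongs to row i.
    assert unsorted_token_ids[0] == texts_to_ids(["short"])[0]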
@@ -394,18 +394,10 @@ def compute_loss(
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)

     if params.att_rate != 0.0:
-        token_ids = graph_compiler.texts_to_ids(texts)
+        token_ids = graph_compiler.texts_to_ids(supervisions["text"])
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
+            mmodel = model.module if hasattr(model, "module") else model
+            att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
                 token_ids=token_ids,
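Both compute_loss hunks also collapse the if/else on hasattr(model, "module") into a single mmodel binding. The attribute check is needed because torch.nn.parallel.DistributedDataParallel wraps the network, so a custom method such as decoder_forward is only reachable through model.module. A rough sketch of the idiom with a toy module; Tiny and its decoder_forward are made up for illustration:

    import torch
    import torch.nn as nn

    class Tiny(nn.Module):
        # toy stand-in for the conformer; it has a custom, non-forward method
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def forward(self, x):
            return self.proj(x)

        def decoder_forward(self, x):
            return self.proj(x) * 2

    model = Tiny()
    # Under DDP the network would be wrapped, e.g.
    #   model = nn.parallel.DistributedDataParallel(Tiny().to(rank), device_ids=[rank])
    # and decoder_forward would then live on model.module instead of model.

    mmodel = model.module if hasattr(model, "module") else model
    out = mmodel.decoder_forward(torch.zeros(1, 4))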
@@ -224,6 +224,7 @@ class Nbest(object):
         else:
             word_seq = lattice.aux_labels.index(path)
+            word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

         word_seq = word_seq.remove_values_leq(0)

         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
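The truncated comment above is about deduplicating the sampled paths: each utterance has num_paths paths, but several of them can transduce to the same word sequence, and only one representative per unique sequence should be kept. A rough pure-Python illustration of that idea; the real code does this on k2 ragged tensors, not Python lists, and the word ids below are hypothetical:

    # Hypothetical word-id sequences for the paths of one utterance.
    paths_words = [
        [12, 7, 9],   # path 0
        [12, 7, 9],   # path 1, same words as path 0
        [12, 30],     # path 2
    ]

    seen = set()
    unique_word_seqs = []
    for words in paths_words:
        key = tuple(words)
        if key not in seen:
            seen.add(key)
            unique_word_seqs.append(words)

    print(unique_word_seqs)  # [[12, 7, 9], [12, 30]]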
@@ -732,6 +733,12 @@ def rescore_with_whole_lattice(
             logging.info(
                 f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
             )
+            logging.info(
+                "This OOM is not an error. You can ignore it. "
+                "If your model does not converge well, or --max-duration "
+                "is too large, or the input sound file is difficult to "
+                "decode, you will meet this exception."
+            )

             # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
             # to avoid OOM. You may need to fine tune it.
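The new logging.info call sits in the out-of-memory handling path of rescore_with_whole_lattice: when composing the lattice with the n-gram LM exhausts GPU memory, the lattice is pruned (the 1e-9 threshold mentioned in the NOTE) and the composition is retried. Below is a rough sketch of that retry pattern; intersect and prune_lattice are hypothetical callables standing in for the actual k2 operations, not the exact icefall code:

    import logging
    import torch

    def compose_with_retry(intersect, prune_lattice, lattice,
                           thresholds=(1e-9, 1e-8, 1e-7)):
        # Try the full composition first; on CUDA OOM, prune the lattice and
        # retry with progressively larger thresholds.
        try:
            return intersect(lattice)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            logging.info(
                "This OOM is not an error. You can ignore it. "
                "If your model does not converge well, or --max-duration "
                "is too large, or the input sound file is difficult to "
                "decode, you will meet this exception."
            )
        for th in thresholds:
            torch.cuda.empty_cache()
            lattice = prune_lattice(lattice, th)
            try:
                return intersect(lattice)
            except RuntimeError:
                continue
        raise RuntimeError("composition failed even after pruning")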
@@ -864,6 +871,7 @@ def rescore_with_attention_decoder(
         ngram_lm_scale_list = [0.01, 0.05, 0.08]
         ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]
@@ -871,6 +879,7 @@ def rescore_with_attention_decoder(
         attention_scale_list = [0.01, 0.05, 0.08]
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         attention_scale_list = [attention_scale]
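These last two hunks extend the grid that rescore_with_attention_decoder searches when no explicit scales are given: every ngram_lm_scale is paired with every attention_scale and one rescored result is kept per pair. A small sketch of how such a grid is typically enumerated; combine_scores and the key format are illustrative, not the exact icefall API:

    from itertools import product

    ngram_lm_scale_list = [0.01, 0.05, 0.08, 0.1, 0.5, 1.0, 2.0, 5.0]
    attention_scale_list = [0.01, 0.05, 0.08, 0.1, 0.5, 1.0, 2.0, 5.0]

    def combine_scores(n_scale, a_scale):
        # stand-in for combining acoustic, n-gram LM, and attention-decoder
        # scores and then picking the best path per utterance
        return -abs(n_scale - 0.5) - abs(a_scale - 1.0)

    ans = {}
    for n_scale, a_scale in product(ngram_lm_scale_list, attention_scale_list):
        key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
        ans[key] = combine_scores(n_scale, a_scale)

    best = max(ans, key=ans.get)
    print(best, ans[best])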