mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-27 18:54:18 +00:00

Merge branch 'k2-fsa:master' into master

Commit 13d972c628
@@ -303,6 +303,10 @@ The commonly used options are:
 
     $ cd egs/librispeech/ASR
     $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300
+    # Caution: The above command is tested with a model with vocab size 500.
+    # The default settings in the master will not work.
+    # Please see https://github.com/k2-fsa/icefall/issues/103
+    # We will fix it later and delete this note.
 
 And the following command uses attention decoder for rescoring:
 
@@ -328,6 +332,8 @@ Usage:
 .. code-block:: bash
 
   $ cd egs/librispeech/ASR
+  # NOTE: Tested with a model with vocab size 500.
+  # It won't work for a model with vocab size 5000.
   $ ./conformer_ctc/decode.py \
       --epoch 25 \
       --avg 1 \
@@ -399,7 +405,7 @@ Download the pre-trained model
 
 The following commands describe how to download the pre-trained model:
 
-.. code-block::
+.. code-block:: bash
 
   $ cd egs/librispeech/ASR
   $ mkdir tmp
@@ -410,10 +416,23 @@ The following commands describe how to download the pre-trained model:
 .. CAUTION::
 
   You have to use ``git lfs`` to download the pre-trained model.
+
+  Otherwise, you will have the following issue when running ``decode.py``:
+
+    .. code-block::
+
+      _pickle.UnpicklingError: invalid load key, 'v'
+
+  To fix that issue, please use:
+
+    .. code-block:: bash
+
+      cd icefall_asr_librispeech_conformer_ctc
+      git lfs pull
 
 .. CAUTION::
 
-  In order to use this pre-trained model, your k2 version has to be v1.7 or later.
+  In order to use this pre-trained model, your k2 version has to be v1.9 or later.
 
 After downloading, you will have the following files:
 
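For context on the ``_pickle.UnpicklingError`` mentioned above: when ``git lfs`` is not installed, the downloaded "checkpoint" is only a small text pointer file whose first bytes are ``version https://git-lfs.github.com/...``, which is why ``torch.load()`` fails on the load key ``'v'``. Below is a small, hypothetical helper (not part of icefall; the size cutoff is an assumption) that detects this situation before loading:

.. code-block:: python

  from pathlib import Path

  def looks_like_lfs_pointer(path: str) -> bool:
      """Return True if `path` is a git-lfs pointer file rather than real data."""
      p = Path(path)
      # Real checkpoints are many megabytes; LFS pointers are ~130 bytes of text.
      if p.stat().st_size > 1024:
          return False
      return p.read_bytes().startswith(b"version https://git-lfs.github.com")

  # Usage: if looks_like_lfs_pointer("exp/pretrained.pt") is True (the path is
  # illustrative), run `git lfs pull` inside the cloned model directory first.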
@@ -362,22 +362,25 @@ def compute_loss(
 
     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=unsorted_token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
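The comment block introduced in the hunk above explains why the token ids must be regenerated from the unsorted ``supervisions["text"]``. As a minimal illustration (a toy sketch, not icefall code; the durations and texts below are made up), sorting supervisions by duration changes the row order, so ids built from the sorted text would no longer line up with ``encoder_memory``:

.. code-block:: python

  import torch

  # Toy batch in its original order.
  durations = torch.tensor([80, 120, 100])
  texts = ["utt0 text", "utt1 is the longest", "utt2 medium"]

  # A sort by decreasing duration, as done when building the training graphs.
  order = torch.argsort(durations, descending=True)
  sorted_texts = [texts[i] for i in order]
  # sorted_texts == ["utt1 is the longest", "utt2 medium", "utt0 text"]

  # encoder_memory keeps the original order, so row 0 still belongs to "utt0".
  # Token ids built from `sorted_texts` would pair row 0 with utt1's words;
  # rebuilding them from the unsorted `texts` keeps rows and transcripts aligned.
  assert sorted_texts[0] != texts[0]

The ``mmodel = model.module if hasattr(model, "module") else model`` line in the same hunk is the usual way to reach methods such as ``decoder_forward`` when the model is wrapped in ``DistributedDataParallel``, which exposes the original module as ``.module``.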
@@ -394,24 +394,16 @@ def compute_loss(
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
     if params.att_rate != 0.0:
-        token_ids = graph_compiler.texts_to_ids(texts)
+        token_ids = graph_compiler.texts_to_ids(supervisions["text"])
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
+            mmodel = model.module if hasattr(model, "module") else model
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
     else:
         loss = mmi_loss
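Both ``compute_loss`` hunks combine the two objectives the same way: ``att_rate`` interpolates between the CTC/MMI loss and the attention-decoder loss. A minimal stand-alone sketch with toy values (not icefall code):

.. code-block:: python

  def interpolate_loss(base_loss: float, att_loss: float, att_rate: float) -> float:
      """att_rate == 0.0 keeps the pure CTC/MMI objective."""
      if att_rate != 0.0:
          return (1.0 - att_rate) * base_loss + att_rate * att_loss
      return base_loss

  # Example: att_rate=0.7 weights the attention loss more heavily.
  print(interpolate_loss(2.0, 4.0, 0.7))  # 0.3 * 2.0 + 0.7 * 4.0 = 3.4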
@@ -224,6 +224,7 @@ class Nbest(object):
         else:
             word_seq = lattice.aux_labels.index(path)
             word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
+        word_seq = word_seq.remove_values_leq(0)
 
         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
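The added ``remove_values_leq(0)`` call drops every value ``<= 0`` from the ragged word sequences, i.e. the epsilon label ``0`` and the final ``-1`` label on arcs entering the super-final state. A hedged sketch of the effect on a ``k2.RaggedTensor`` (the toy values are made up, and the constructor call assumes a recent k2 release):

.. code-block:: python

  import k2

  # Two toy "word sequences" containing epsilons (0) and final -1 labels.
  word_seq = k2.RaggedTensor([[0, 15, 0, 42, -1], [0, 7, -1]])
  print(word_seq.remove_values_leq(0))
  # Expected contents of the result: [[15, 42], [7]]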
@@ -732,6 +733,12 @@ def rescore_with_whole_lattice(
             logging.info(
                 f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
             )
+            logging.info(
+                "This OOM is not an error. You can ignore it. "
+                "If your model does not converge well, or --max-duration "
+                "is too large, or the input sound file is difficult to "
+                "decode, you will meet this exception."
+            )
 
             # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
             # to avoid OOM. You may need to fine tune it.
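The new ``logging.info`` call above sits in an except-branch that prunes the lattice and retries after a CUDA out-of-memory error. The following is a minimal sketch of that retry pattern under stated assumptions: ``intersect_fn`` and ``prune_fn`` are hypothetical callables standing in for the k2 intersection and pruning steps, and the threshold list is illustrative only.

.. code-block:: python

  import logging
  from typing import Callable, Sequence, TypeVar

  T = TypeVar("T")

  def intersect_with_oom_retry(
      lattice: T,
      intersect_fn: Callable[[T], T],
      prune_fn: Callable[[T, float], T],
      thresholds: Sequence[float] = (1e-10, 1e-9, 1e-8),
  ) -> T:
      """Try the intersection; on a CUDA OOM, prune the lattice and retry."""
      for th in thresholds:
          try:
              return intersect_fn(lattice)
          except RuntimeError as e:
              if "out of memory" not in str(e).lower():
                  raise
              logging.info(
                  "This OOM is not an error. You can ignore it. "
                  "Pruning with threshold %g and retrying.", th
              )
              lattice = prune_fn(lattice, th)
      # Final attempt outside the loop so a persistent OOM reaches the caller.
      return intersect_fn(lattice)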
@@ -864,6 +871,7 @@ def rescore_with_attention_decoder(
         ngram_lm_scale_list = [0.01, 0.05, 0.08]
         ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]
 
@@ -871,6 +879,7 @@ def rescore_with_attention_decoder(
         attention_scale_list = [0.01, 0.05, 0.08]
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         attention_scale_list = [attention_scale]
 
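The two lists extended above feed a grid search: during rescoring, every ``(ngram_lm_scale, attention_scale)`` pair is evaluated and stored under its own key, so WERs can be compared per pair afterwards. A hedged sketch of that loop (the ``combined_score`` function, the toy scores, and the key format are illustrative, not the exact icefall code):

.. code-block:: python

  ngram_lm_scale_list = [0.01, 0.05, 0.08]
  ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
  ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
  ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
  attention_scale_list = list(ngram_lm_scale_list)  # same grid after this change

  def combined_score(am: float, lm: float, att: float,
                     n_scale: float, a_scale: float) -> float:
      # Per-path total: acoustic score + scaled n-gram LM + scaled attention score.
      return am + n_scale * lm + a_scale * att

  ans = {}
  for n_scale in ngram_lm_scale_list:
      for a_scale in attention_scale_list:
          key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
          ans[key] = combined_score(-10.0, -3.0, -2.0, n_scale, a_scale)  # toy scores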