Mirror of https://github.com/k2-fsa/icefall.git
Use pre-sorted text to generate token ids for attention decoder. (#98)
* Use pre-sorted text to generate token ids for attention decoder.
  See https://github.com/k2-fsa/icefall/issues/97 for more details.
* Fix typos.
commit 42b437bea6
parent 12d647d899
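Why the change is needed, in one picture: `encode_supervisions()` sorts the utterances before the CTC/MMI loss is computed, while `encoder_memory` and `memory_mask` keep the original batch order, so token ids built from the sorted texts end up paired with the wrong encoder outputs. Below is a minimal, self-contained sketch of that mismatch with toy data (not icefall code; the duration-descending sort is an assumption of the sketch):

    # Toy batch: three utterances in their original (collation) order.
    texts = ["HELLO WORLD", "HI", "GOOD MORNING EVERYONE"]
    durations = [2.1, 0.8, 3.5]

    def texts_to_ids(batch):
        # Toy tokenizer: one "token id" per character, purely for illustration.
        return [[ord(c) for c in t] for t in batch]

    # encode_supervisions()-style sorting: longest utterance first (assumed).
    order = sorted(range(len(texts)), key=lambda i: -durations[i])
    sorted_texts = [texts[i] for i in order]

    # Mismatch: these token ids follow the *sorted* order, but encoder_memory
    # and memory_mask are still in the *original* order, so utterance 0's
    # encoder output would be decoded against utterance 2's transcript.
    sorted_token_ids = texts_to_ids(sorted_texts)
    assert sorted_token_ids[0] != texts_to_ids(texts)[0]

    # What the commit does: regenerate token ids from the unsorted transcripts
    # (supervisions["text"] in the diff) so they line up with the encoder
    # outputs again.
    unsorted_token_ids = texts_to_ids(texts)
    assert unsorted_token_ids[0] == [ord(c) for c in "HELLO WORLD"]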
@@ -362,19 +362,22 @@ def compute_loss(
 
     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
             )
-            else:
-                att_loss = model.decoder_forward(
+            att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
-                token_ids=token_ids,
+                token_ids=unsorted_token_ids,
                 sos_id=graph_compiler.sos_id,
                 eos_id=graph_compiler.eos_id,
             )
@@ -394,18 +394,10 @@ def compute_loss(
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
     if params.att_rate != 0.0:
-        token_ids = graph_compiler.texts_to_ids(texts)
+        token_ids = graph_compiler.texts_to_ids(supervisions["text"])
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
+            mmodel = model.module if hasattr(model, "module") else model
+            att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
                 token_ids=token_ids,
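Both hunks also replace the `if hasattr(model, "module"): ... else: ...` duplication with a single unwrapping step. This is the usual idiom for code that must accept either a bare model or one wrapped in DistributedDataParallel, which exposes the underlying model as `.module`. A minimal sketch of just that idiom (the `Wrapper` class below is a stand-in for DDP, not icefall code):

    import torch.nn as nn

    class ToyDecoder(nn.Module):
        """Stand-in for the real model; only decoder_forward matters here."""

        def decoder_forward(self, x):
            return x + 1

    class Wrapper(nn.Module):
        """Mimics DDP in one respect: the wrapped model lives in `.module`."""

        def __init__(self, module):
            super().__init__()
            self.module = module

    for model in (ToyDecoder(), Wrapper(ToyDecoder())):
        # Same idiom as in the diff: unwrap once, then call decoder_forward
        # without duplicating the call site in an if/else.
        mmodel = model.module if hasattr(model, "module") else model
        print(mmodel.decoder_forward(1))  # prints 2 in both cases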