mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Use pre-sorted text to generate token ids for attention decoder. (#98)
* Use pre-sorted text to generate token ids for attention decoder. See https://github.com/k2-fsa/icefall/issues/97 for more details. * Fix typos.
This commit is contained in:
parent
12d647d899
commit
42b437bea6
@ -362,22 +362,25 @@ def compute_loss(
|
||||
|
||||
if params.att_rate != 0.0:
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if hasattr(model, "module"):
|
||||
att_loss = model.module.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
mmodel = model.module if hasattr(model, "module") else model
|
||||
# Note: We need to generate an unsorted version of token_ids
|
||||
# `encode_supervisions()` called above sorts text, but
|
||||
# encoder_memory and memory_mask are not sorted, so we
|
||||
# use an unsorted version `supervisions["text"]` to regenerate
|
||||
# the token_ids
|
||||
#
|
||||
# See https://github.com/k2-fsa/icefall/issues/97
|
||||
# for more details
|
||||
unsorted_token_ids = graph_compiler.texts_to_ids(
|
||||
supervisions["text"]
|
||||
)
|
||||
att_loss = mmodel.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=unsorted_token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
|
||||
else:
|
||||
loss = ctc_loss
|
||||
|
@ -394,24 +394,16 @@ def compute_loss(
|
||||
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
|
||||
|
||||
if params.att_rate != 0.0:
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
token_ids = graph_compiler.texts_to_ids(supervisions["text"])
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if hasattr(model, "module"):
|
||||
att_loss = model.module.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
mmodel = model.module if hasattr(model, "module") else model
|
||||
att_loss = mmodel.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
|
||||
else:
|
||||
loss = mmi_loss
|
||||
|
@ -394,24 +394,16 @@ def compute_loss(
|
||||
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
|
||||
|
||||
if params.att_rate != 0.0:
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
token_ids = graph_compiler.texts_to_ids(supervisions["text"])
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if hasattr(model, "module"):
|
||||
att_loss = model.module.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
mmodel = model.module if hasattr(model, "module") else model
|
||||
att_loss = mmodel.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
|
||||
else:
|
||||
loss = mmi_loss
|
||||
|
Loading…
x
Reference in New Issue
Block a user