From 42b437bea630ae5279a1ce1fadd6c6b4de8f2f1b Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 29 Oct 2021 13:46:41 +0800
Subject: [PATCH] Use pre-sorted text to generate token ids for attention
 decoder. (#98)

* Use pre-sorted text to generate token ids for attention decoder.

See https://github.com/k2-fsa/icefall/issues/97 for more details.

* Fix typos.
---
 egs/librispeech/ASR/conformer_ctc/train.py    | 35 ++++++++++---------
 .../ASR/conformer_mmi/train-with-attention.py | 26 +++++---------
 egs/librispeech/ASR/conformer_mmi/train.py    | 26 +++++---------
 3 files changed, 37 insertions(+), 50 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 223c8d993..1384204dd 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -362,22 +362,25 @@ def compute_loss(
 
     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=unsorted_token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
index 8b8994059..011dadd73 100755
--- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
+++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
@@ -394,24 +394,16 @@ def compute_loss(
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
     if params.att_rate != 0.0:
-        token_ids = graph_compiler.texts_to_ids(texts)
+        token_ids = graph_compiler.texts_to_ids(supervisions["text"])
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
+            mmodel = model.module if hasattr(model, "module") else model
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
     else:
         loss = mmi_loss
diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py
index 6580792ff..c36677762 100755
--- a/egs/librispeech/ASR/conformer_mmi/train.py
+++ b/egs/librispeech/ASR/conformer_mmi/train.py
@@ -394,24 +394,16 @@ def compute_loss(
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
     if params.att_rate != 0.0:
-        token_ids = graph_compiler.texts_to_ids(texts)
+        token_ids = graph_compiler.texts_to_ids(supervisions["text"])
        with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
+            mmodel = model.module if hasattr(model, "module") else model
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
     else:
         loss = mmi_loss
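
Note (not part of the patch): below is a minimal, self-contained sketch of
the ordering mismatch this change fixes. The utterance texts, durations,
and the toy texts_to_ids() are made up for illustration; only the sorting
behaviour of encode_supervisions() is taken from the issue linked above
(https://github.com/k2-fsa/icefall/issues/97).

    # Toy stand-in for graph_compiler.texts_to_ids(): map each character
    # of a transcript to an integer id.
    def texts_to_ids(texts):
        return [[ord(c) for c in t] for t in texts]

    texts = ["HELLO", "HI", "GOOD MORNING"]  # original batch order
    durations = [5, 2, 12]                   # assumed frame counts

    # encode_supervisions() returns texts sorted by decreasing duration,
    # as required when building the DenseFsaVec for the CTC/MMI loss.
    order = sorted(range(len(texts)), key=lambda i: -durations[i])
    sorted_texts = [texts[i] for i in order]  # ["GOOD MORNING", "HELLO", "HI"]

    sorted_ids = texts_to_ids(sorted_texts)  # what the old code fed the decoder
    unsorted_ids = texts_to_ids(texts)       # what this patch feeds the decoder

    # encoder_memory and memory_mask are produced in the original batch
    # order, so their first row corresponds to "HELLO"; only unsorted_ids
    # lines up with them.
    assert unsorted_ids[0] == [ord(c) for c in "HELLO"]
    assert sorted_ids[0] != unsorted_ids[0]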