diff --git a/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp b/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp index 7f6616eff..b47ef3dee 100644 Binary files a/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp and b/egs/librispeech/ASR/conformer_ctc2/.decode.py.swp differ diff --git a/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp b/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp index 9311babfd..6634a7b61 100644 Binary files a/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp and b/egs/librispeech/ASR/conformer_ctc2/.transformer.py.swp differ diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index b3d0e6038..8b38b0908 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -415,6 +415,14 @@ def decode_one_batch( return {key: hyps} if params.method == "greedy-search": + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + warmup=warmup, + ) hyps = greedy_search(nnet_output, memory_key_padding_mask) # hyps is a list of str, e.g., ['xxx yyy zzz', ...]