diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 96f6134f1..f5ffe026e 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -239,7 +239,7 @@ def decode_one_batch(
         is a 3-gram LM, while this G is a 4-gram LM.
     Returns:
       Return the decoding result. See above description for the format of
-      the returned dict.
+      the returned dict. Note: If it decodes to nothing, then return None.
     """
     if HLG is not None:
         device = HLG.device
@@ -392,8 +392,7 @@ def decode_one_batch(
         hyps = [[word_table[i] for i in ids] for ids in hyps]
         ans[lm_scale_str] = hyps
     else:
-        for lm_scale in lm_scale_list:
-            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
+        ans = None
     return ans
 
 
@@ -467,16 +466,29 @@ def decode_dataset(
             eos_id=eos_id,
         )
 
-        for lm_scale, hyps in hyps_dict.items():
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert (
+                len(results) > 0
+            ), "It should not decode to empty in the first batch!"
             this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
+            hyp_words = []
+            for ref_text in texts:
                 ref_words = ref_text.split()
                 this_batch.append((ref_words, hyp_words))
 
-            results[lm_scale].extend(this_batch)
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
 
-        num_cuts += len(batch["supervisions"]["text"])
+        num_cuts += len(texts)
 
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
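
A note on the hunk at `@@ -392,8 +392,7 @@`: the removed fallback `ans[f"{lm_scale}"] = [[] * lattice.shape[0]]` never produced one empty hypothesis per utterance, because `[] * n` is still `[]`. The sketch below is a minimal, standalone illustration in plain Python; the `results` and `texts` values are made-up stand-ins for the variables in `decode_dataset`, not icefall code.

    # Sketch (plain Python, outside icefall) of the bug this patch removes.
    n = 3                     # stand-in for lattice.shape[0] (batch size)
    broken = [[] * n]         # [] * 3 == [], so this is always [[]]
    assert broken == [[]]     # length 1, never n

    # What the old fallback presumably intended:
    intended = [[] for _ in range(n)]
    assert intended == [[], [], []]

    # The patch sidesteps the construction entirely: decode_one_batch now
    # returns None, and decode_dataset pads every previously seen lm_scale
    # key with empty hypotheses. Toy values below are hypothetical:
    results = {"lm_scale_0.5": [], "lm_scale_0.7": []}
    texts = ["HELLO WORLD", "GOOD MORNING"]
    hyps_dict = None          # i.e., the batch decoded to nothing

    if hyps_dict is None:
        this_batch = [(ref_text.split(), []) for ref_text in texts]
        for lm_scale in results:
            results[lm_scale].extend(this_batch)

    assert all(len(v) == len(texts) for v in results.values())

The `len(results) > 0` assertion in the patch appears to exist because the `lm_scale` keys of `results` are only learned from a successfully decoded batch; if the very first batch returned None, there would be no keys to pad with empty hypotheses.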