diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index b51ebcfe3..f386bcdd0 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -297,7 +297,7 @@ def decode_one_batch( generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - + return {"beam-search": hyps} @@ -383,6 +383,8 @@ def decode_dataset( for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_text = normalize_text_alimeeting(ref_text) ref_words = ref_text.split() + print(f"ref: {ref_text}") + print(f"hyp: {''.join(hyp_words)}") this_batch.append((cut_id, ref_words, hyp_words)) results[lm_scale].extend(this_batch)