diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 79b6a6097..6efee5e4d 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -296,6 +296,10 @@ def decode_one_batch( # hyps = to_simple(hyps) # hyps = [params.normalizer.normalize(hyp) for hyp in hyps] print(hyps) + texts = batch["supervisions"]["text"] + for i, text in enumerate(texts): + print(f"ref: {text}") + print(f"hyp: {hyps[i]}") return {"beam-search": hyps} @@ -476,7 +480,8 @@ def main(): if params.use_flash_attn: attn_implementation = "flash_attention_2" - torch_dtype=torch.bfloat16 + # torch_dtype=torch.bfloat16 + torch_dtype=torch.float16 else: attn_implementation = "eager"