From 68b99f456f2b42b920674bb949de37f0ead8e062 Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Fri, 7 Jun 2024 09:53:40 +0800
Subject: [PATCH] fix debug

---
 egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index 6efee5e4d..7bdfc57dd 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -290,11 +290,9 @@ def decode_one_batch(
     )
 
     generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
-    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    # hyps = remove_punctuation(hyps)
-    # hyps = to_simple(hyps)
-    # hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
+    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+    print(hyps)
 
     texts = batch["supervisions"]["text"]
     for i, text in enumerate(texts):
@@ -381,7 +379,7 @@ def decode_dataset(
 
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
-            # assert len(hyps) == len(texts)
+            assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                 ref_text = normalize_text_alimeeting(ref_text)
                 ref_words = ref_text.split()
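
Reviewer note (not part of the patch): tokenizer.batch_decode() returns a
list with one decoded string per generated sequence, so the old trailing
[0] silently collapsed the whole batch to the first utterance's hypothesis;
that is exactly the mismatch that the re-enabled
assert len(hyps) == len(texts) in decode_dataset() would now catch. A
minimal sketch of the behavior, assuming any Hugging Face tokenizer
("gpt2" below is an illustrative stand-in for the LLM tokenizer that
decode.py actually loads):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Stand-in for model.decode() output: one token-ID sequence per utterance.
    generated_ids = tokenizer(["hello world", "good morning"])["input_ids"]

    # Fixed code: one hypothesis string per utterance in the batch.
    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    assert len(hyps) == len(generated_ids)

    # Old code: the trailing [0] keeps only utterance 0's hypothesis.
    first_only = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    assert isinstance(first_only, str)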