diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index 6efee5e4d..7bdfc57dd 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -290,11 +290,9 @@ def decode_one_batch(
     )
     generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device))
-    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    # hyps = remove_punctuation(hyps)
-    # hyps = to_simple(hyps)
-    # hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
+    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+    print(hyps)

     texts = batch["supervisions"]["text"]
     for i, text in enumerate(texts):
@@ -381,7 +379,7 @@ def decode_dataset(
     for lm_scale, hyps in hyps_dict.items():
         this_batch = []

-        # assert len(hyps) == len(texts)
+        assert len(hyps) == len(texts)
         for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
             ref_text = normalize_text_alimeeting(ref_text)
             ref_words = ref_text.split()
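
Two things stand out in the added lines of the first hunk: the skip_special_tokens=True result is immediately overwritten by the skip_special_tokens=False call, and print(hyps) looks like leftover debug output. A minimal sketch of how that block could read if the intent is to score on the cleaned hypotheses while still being able to inspect the raw decode (this assumes tokenizer is the Hugging Face tokenizer already used in the file and that the standard logging module is imported; raw_hyps is an illustrative name, not from the patch):

    hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep the un-cleaned decode (special tokens included) only for debugging,
    # instead of overwriting hyps with it.
    raw_hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    logging.debug("raw hypotheses: %s", raw_hyps)

Dropping the [0] indexing also turns hyps back into a per-utterance list rather than a single string, which is what allows the re-enabled assert len(hyps) == len(texts) in decode_dataset to hold.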