diff --git a/egs/speech_llm/ASR_LLM/debug.sh b/egs/speech_llm/ASR_LLM/debug.sh index 53f411870..255de44cd 100755 --- a/egs/speech_llm/ASR_LLM/debug.sh +++ b/egs/speech_llm/ASR_LLM/debug.sh @@ -17,8 +17,8 @@ export CUDA_VISIBLE_DEVICES=0,1 python3 ./whisper_llm_zh/decode.py \ --max-duration 80 \ - --exp-dir ./whisper_llm_zh/exp_test \ - --speech-encoder-path-or-name tiny \ + --exp-dir ./whisper_llm_zh/exp_qwen_0.5B \ + --speech-encoder-path-or-name /mnt/samsung-t7/yuekai/asr/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ --llm-path-or-name Qwen/Qwen1.5-0.5B-Chat \ --epoch 1 --avg 1 \ --manifest-dir data/fbank \ diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 54a083c1e..79b6a6097 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -291,7 +291,7 @@ def decode_one_batch( generated_ids = model.decode(feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)) hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] - + hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0] # hyps = remove_punctuation(hyps) # hyps = to_simple(hyps) # hyps = [params.normalizer.normalize(hyp) for hyp in hyps] @@ -377,7 +377,7 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) + # assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): ref_text = normalize_text_alimeeting(ref_text) ref_words = ref_text.split() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py index 796cb2c9d..5d4585b11 100644 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py @@ -169,7 +169,8 @@ class SPEECH_LLM(nn.Module): eos_token_id=self.llm.config.eos_token_id, pad_token_id=self.llm.config.pad_token_id ) - generated_ids = [ - output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids) - ] + # print(generated_ids, input_ids) + # generated_ids = [ + # output_ids[len(input_ids):] for input_ids, output_ids in zip(input_ids, generated_ids) + # ] return generated_ids \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt index f86357d68..bde73601f 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt @@ -1,7 +1,8 @@ k2 kaldialign git+https://github.com/lhotse-speech/lhotse -# sentencepiece +sentencepiece +pypinyin tensorboard librosa # git+https://github.com/yuekaizhang/whisper.git