From 458d697accdceea6bc86bc3978000cb88c69044a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Apr 2025 13:41:33 +0000 Subject: [PATCH] fix batch_size>1 decoding bug --- egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py | 1 + egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py index 3feef8f1c..66ccd9974 100755 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py @@ -342,6 +342,7 @@ def decode_one_batch( # {"role": "user", "content": f"{last_questions[i]}"}, {"role": "assistant", "content": ""} ] + print(f"message: {message}, batch_size {len(chat_rounds)}") messages.append(message) input_ids, attention_mask = preprocess(messages, tokenizer) diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py index 5126a5d34..55541f03e 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py +++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py @@ -241,7 +241,7 @@ class SPEECH_LLM(nn.Module): inputs_embeds = self.llm.get_input_embeddings()(input_ids) ( inputs_embeds, - _, + attention_mask, _, _, ) = self._merge_input_ids_with_speech_features( @@ -249,6 +249,7 @@ class SPEECH_LLM(nn.Module): ) generated_ids = self.llm.generate( inputs_embeds=inputs_embeds, + attention_mask=attention_mask, max_new_tokens=kwargs.get("max_new_tokens", 1024), num_beams=kwargs.get("num_beams", 1), do_sample=kwargs.get("do_sample", True),