fix batch_size>1 decoding bug

root 2025-04-15 13:41:33 +00:00
parent 0c02da82ac
commit 458d697acc
2 changed files with 3 additions and 1 deletion


@@ -342,6 +342,7 @@ def decode_one_batch(
             # {"role": "user", "content": f"{last_questions[i]}"},
             {"role": "assistant", "content": ""}
         ]
+        print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
     input_ids, attention_mask = preprocess(messages, tokenizer)


@@ -241,7 +241,7 @@ class SPEECH_LLM(nn.Module):
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
-            _,
+            attention_mask,
             _,
             _,
         ) = self._merge_input_ids_with_speech_features(
@@ -249,6 +249,7 @@ class SPEECH_LLM(nn.Module):
         )
         generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             max_new_tokens=kwargs.get("max_new_tokens", 1024),
             num_beams=kwargs.get("num_beams", 1),
             do_sample=kwargs.get("do_sample", True),
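Context for the fix: with batch_size > 1, the sequences in a batch are padded to a common length, and once generate() receives inputs_embeds instead of input_ids it has no way to tell real tokens from padding. Capturing the attention_mask returned by _merge_input_ids_with_speech_features and forwarding it to llm.generate masks out the pad positions. Below is a minimal standalone sketch of the same failure mode with a generic Hugging Face decoder-only model; the model choice and prompts are illustrative, not from this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# GPT-2 has no pad token; decoder-only models should be left-padded
# for batched generation.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Two prompts of different lengths force padding once they share a batch.
batch = tokenizer(
    ["Hi", "The quick brown fox jumps over the lazy"],
    return_tensors="pt",
    padding=True,
)
inputs_embeds = model.get_input_embeddings()(batch["input_ids"])

# Buggy: given inputs_embeds alone, generate() cannot recover the mask and
# assumes every position is real, so the shorter sequence attends to its
# pad embeddings and its continuation is corrupted whenever batch_size > 1.
bad = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=8)

# Fixed: forwarding the attention mask restores correct per-sequence
# decoding, which is what this commit does inside SPEECH_LLM.generate.
good = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=batch["attention_mask"],
    max_new_tokens=8,
)
print(tokenizer.batch_decode(good, skip_special_tokens=True))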