fix batch_size>1 decoding bug

root 2025-04-15 13:41:33 +00:00
parent 0c02da82ac
commit 458d697acc
2 changed files with 3 additions and 1 deletion


@@ -342,6 +342,7 @@ def decode_one_batch(
             # {"role": "user", "content": f"{last_questions[i]}"},
             {"role": "assistant", "content": ""}
         ]
+        print(f"message: {message}, batch_size {len(chat_rounds)}")
         messages.append(message)
     input_ids, attention_mask = preprocess(messages, tokenizer)


@@ -241,7 +241,7 @@ class SPEECH_LLM(nn.Module):
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
         (
             inputs_embeds,
-            _,
+            attention_mask,
             _,
             _,
         ) = self._merge_input_ids_with_speech_features(
@@ -249,6 +249,7 @@ class SPEECH_LLM(nn.Module):
         )
         generated_ids = self.llm.generate(
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             max_new_tokens=kwargs.get("max_new_tokens", 1024),
             num_beams=kwargs.get("num_beams", 1),
             do_sample=kwargs.get("do_sample", True),