mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
fix batch_size>1 decoding bug
This commit is contained in:
parent
0c02da82ac
commit
458d697acc
@ -342,6 +342,7 @@ def decode_one_batch(
|
||||
# {"role": "user", "content": f"{last_questions[i]}"},
|
||||
{"role": "assistant", "content": ""}
|
||||
]
|
||||
print(f"message: {message}, batch_size {len(chat_rounds)}")
|
||||
messages.append(message)
|
||||
|
||||
input_ids, attention_mask = preprocess(messages, tokenizer)
|
||||
|
||||
@ -241,7 +241,7 @@ class SPEECH_LLM(nn.Module):
|
||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||
(
|
||||
inputs_embeds,
|
||||
_,
|
||||
attention_mask,
|
||||
_,
|
||||
_,
|
||||
) = self._merge_input_ids_with_speech_features(
|
||||
@ -249,6 +249,7 @@ class SPEECH_LLM(nn.Module):
|
||||
)
|
||||
generated_ids = self.llm.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=kwargs.get("max_new_tokens", 1024),
|
||||
num_beams=kwargs.get("num_beams", 1),
|
||||
do_sample=kwargs.get("do_sample", True),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user