mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
fix batch_size>1 decoding bug
This commit is contained in:
parent
0c02da82ac
commit
458d697acc
@ -342,6 +342,7 @@ def decode_one_batch(
|
|||||||
# {"role": "user", "content": f"{last_questions[i]}"},
|
# {"role": "user", "content": f"{last_questions[i]}"},
|
||||||
{"role": "assistant", "content": ""}
|
{"role": "assistant", "content": ""}
|
||||||
]
|
]
|
||||||
|
print(f"message: {message}, batch_size {len(chat_rounds)}")
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
input_ids, attention_mask = preprocess(messages, tokenizer)
|
input_ids, attention_mask = preprocess(messages, tokenizer)
|
||||||
|
|||||||
@ -241,7 +241,7 @@ class SPEECH_LLM(nn.Module):
|
|||||||
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
inputs_embeds = self.llm.get_input_embeddings()(input_ids)
|
||||||
(
|
(
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
_,
|
attention_mask,
|
||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
) = self._merge_input_ids_with_speech_features(
|
) = self._merge_input_ids_with_speech_features(
|
||||||
@ -249,6 +249,7 @@ class SPEECH_LLM(nn.Module):
|
|||||||
)
|
)
|
||||||
generated_ids = self.llm.generate(
|
generated_ids = self.llm.generate(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
max_new_tokens=kwargs.get("max_new_tokens", 1024),
|
max_new_tokens=kwargs.get("max_new_tokens", 1024),
|
||||||
num_beams=kwargs.get("num_beams", 1),
|
num_beams=kwargs.get("num_beams", 1),
|
||||||
do_sample=kwargs.get("do_sample", True),
|
do_sample=kwargs.get("do_sample", True),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user