add flash attn support

Repository: https://github.com/k2-fsa/icefall.git
Commit: 7db40052d6 (parent: b305cdacc0)
@@ -287,19 +287,50 @@ class SPEECH_LLM(nn.Module):
                 device=input_ids.device
             )
             audio_labels = audio_codes.clone()
             total_len = audio_codes.shape[1]

             for i, speech_codec in enumerate(speech_codec_ids):
                 text_label_start_index = text_label_start_index_list[i]
                 speech_codec = torch.tensor(
                     speech_codec, dtype=torch.int64, device=input_ids.device
                 )
-                audio_codes[i, : text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
-                audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
-                audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
-                audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
+                speech_codec_len = len(speech_codec)
+
+                # Lengths of the non-padding content
+                codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
+                # Actual label content length (speech codec tokens + eos token)
+                labels_actual_content_len = speech_codec_len + 1
+
+                if self.codec_lm_padding_side == "right":
+                    # Fill audio_codes (right padding)
+                    codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+                    audio_codes[i, : text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
+                    audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
+
+                    # Fill audio_labels (right padding)
+                    labels_start_idx = text_label_start_index + delay_step
+                    labels_speech_end_idx = labels_start_idx + speech_codec_len
+                    audio_labels[i, labels_start_idx:labels_speech_end_idx] = speech_codec
+                    audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+                elif self.codec_lm_padding_side == "left":
+                    # Start indices for left padding (content is shifted to the right)
+                    codes_start_idx = total_len - codes_len
+                    labels_start_idx = total_len - labels_actual_content_len  # start of the actual label content
+
+                    # Fill audio_codes (left padding)
+                    codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
+                    audio_codes[i, codes_start_idx:codes_speech_start_idx] = self.codec_lm.config.bos_token_id  # mask token_id
+                    audio_codes[i, codes_speech_start_idx:total_len] = speech_codec
+
+                    # Fill audio_labels (left padding)
+                    labels_speech_end_idx = labels_start_idx + speech_codec_len
+                    # Note: the leading positions stay pad_token_id
+                    audio_labels[i, labels_start_idx:labels_speech_end_idx] = speech_codec
+                    audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+                else:
+                    raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")

             audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
             audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
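
As a sanity check on the index arithmetic, here is a self-contained sketch of the right-padding layout with made-up values (total_len = 12, delay_step = 2, text_label_start_index = 3) and placeholder ids standing in for self.codec_lm.config.{bos,eos,pad}_token_id; the left-padding branch produces the same layout shifted to the end of the buffer:

```python
import torch

# Placeholder ids; the real values come from self.codec_lm.config.
bos, eos, pad = 0, 1, 2
total_len, delay_step, text_label_start_index = 12, 2, 3
speech_codec = torch.tensor([7, 8, 9])
speech_codec_len = len(speech_codec)

codes = torch.full((total_len,), pad)
labels = torch.full((total_len,), pad)

# Right padding: mask (bos) tokens up to the delayed start, then the codec tokens.
start = text_label_start_index + delay_step + 1  # = 6
codes[:start] = bos
codes[start : start + speech_codec_len] = speech_codec

# Labels start one position earlier, so the model at position t is trained to
# predict the codec token that appears at position t + 1; eos closes the sequence.
labels[start - 1 : start - 1 + speech_codec_len] = speech_codec
labels[start - 1 + speech_codec_len] = eos

print(codes.tolist())   # [0, 0, 0, 0, 0, 0, 7, 8, 9, 2, 2, 2]
print(labels.tolist())  # [2, 2, 2, 2, 2, 7, 8, 9, 1, 2, 2, 2]
```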
@@ -308,7 +339,24 @@ class SPEECH_LLM(nn.Module):
         text_input_embeds = inputs_embeds + text_last_hidden_outputs
         text_input_embeds = self.speech_token_projector(text_input_embeds)

-        audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds
+        T_merged = text_input_embeds.shape[1]
+        T_audio = audio_embeddings.shape[1]
+
+        if self.codec_lm_padding_side == "right":
+            # Add at the beginning for right padding
+            audio_embeddings[:, :T_merged] += text_input_embeds
+        elif self.codec_lm_padding_side == "left":
+            # Need to add at the shifted position for left padding.
+            # Compute the length of the non-padded sequence for each item.
+            seq_lens = audio_attention_mask.sum(dim=1)  # shape (B,)
+            for i in range(audio_embeddings.shape[0]):
+                item_len = seq_lens[i].item()  # non-padded length of item i
+                start_idx_content = T_audio - item_len  # start index of the content of item i
+                end_idx_target = start_idx_content + T_merged  # end of the target slice within the content
+                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+        else:
+            raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")

         speech_outputs = self.codec_lm(
             attention_mask=audio_attention_mask,
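
The left-padding branch only changes where the projected text embeddings are added: at each item's shifted content start rather than at index 0. A toy check of that offset arithmetic (hypothetical shapes; B = 1, T_audio = 10, non-padded length 7, T_merged = 4):

```python
import torch

B, T_audio, T_merged, D = 1, 10, 4, 8
item_len = 7  # non-padded length of item 0 (from the attention mask)

audio_embeddings = torch.zeros(B, T_audio, D)
text_input_embeds = torch.ones(B, T_merged, D)

start_idx_content = T_audio - item_len         # = 3, first non-padded position
end_idx_target = start_idx_content + T_merged  # = 7
audio_embeddings[0, start_idx_content:end_idx_target] += text_input_embeds[0]

# Padding positions 0..2 stay zero; positions 3..6 received the text embeddings.
print(audio_embeddings[0, :, 0].tolist())
# [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```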
@@ -383,7 +383,7 @@ def compute_loss(
     last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
     history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
     # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎，万物生长，阳光明媚，享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
-    # <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为“他是一个有善心的人。
+    # <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。

     messages = []
     for i, total_round in enumerate(chat_rounds):
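
The commented-out lines show a sample Chinese dialogue history; functionally, the two comprehensions split the prompt at the `<USER>:` marker of the final turn. A quick illustration with a made-up English history string:

```python
q = (
    "USER: List some headlines. ASSISTANT: The news never stops. "
    "<USER>: Tell me how to cook chicken."
)

last_question = q.split('<USER>: ')[-1].strip()
history_context = q.rsplit('<USER>:', 1)[0].strip()

print(last_question)    # Tell me how to cook chicken.
print(history_context)  # USER: List some headlines. ASSISTANT: The news never stops.
```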
@@ -730,11 +730,14 @@ def run(rank, world_size, args):


     if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
         if params.use_flash_attn:
             attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # or torch.bfloat16 if needed/supported
         else:
             attn_implementation = "eager"
+            torch_dtype = torch.float16
-        torch_dtype = torch.float16

         # codec_lm = AutoModelForCausalLM.from_pretrained(
         #     params.llm_path_or_name,
         #     attn_implementation=attn_implementation,
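
The commit trusts `params.use_flash_attn` as given. A defensive variant (not part of the commit) could also verify that flash-attn is actually usable before requesting it; `is_flash_attn_2_available` is provided by `transformers.utils` in recent versions:

```python
import torch
from transformers.utils import is_flash_attn_2_available

def pick_attn_backend(use_flash_attn: bool):
    """Return (attn_implementation, torch_dtype) with a safe fallback to eager."""
    if use_flash_attn and is_flash_attn_2_available() and torch.cuda.is_available():
        # flash-attn kernels require fp16 or bf16 inputs
        return "flash_attention_2", torch.float16
    return "eager", torch.float16
```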
@@ -750,7 +753,14 @@ def run(rank, world_size, args):
             intermediate_size=2048,
             max_position_embeddings=4096,
         )
-        codec_lm = Qwen2ForCausalLM(config=config)
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor;
+        # use AutoModelForCausalLM.from_config for more generality.
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
         codec_lm.vocab_size = codec_vocab_size
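
For reference, a minimal standalone sketch of this construction path. Only `intermediate_size` and `max_position_embeddings` are visible in the hunk, so the other config values below are placeholders; `AutoModelForCausalLM.from_config` forwards `attn_implementation` and `torch_dtype` to the model constructor:

```python
import torch
from transformers import AutoModelForCausalLM, Qwen2Config

# hidden_size and num_hidden_layers are assumed values for illustration.
config = Qwen2Config(
    hidden_size=1024,
    num_hidden_layers=12,
    intermediate_size=2048,
    max_position_embeddings=4096,
)
codec_lm = AutoModelForCausalLM.from_config(
    config=config,
    attn_implementation="eager",  # or "flash_attention_2" on supported GPUs
    torch_dtype=torch.float16,
)
print(codec_lm.dtype)  # torch.float16
```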
@@ -829,7 +839,7 @@ def run(rank, world_size, args):
                 return False
             # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
             codec_len = len(c.custom["answer_cosyvoice_speech_token"])
-            if codec_len > 2048:
+            if codec_len > 2200:
                 logging.warning(
                     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
                 )
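
For context, these lines sit inside a cut-filtering predicate that raises the codec-length cutoff from 2048 to 2200 tokens. A sketch of how such a filter is typically wired up with lhotse-style cuts (the `remove_long_codec_cuts` name and the `.filter(...)` call are illustrative, not from the commit):

```python
import logging

def remove_long_codec_cuts(c) -> bool:
    """Keep only cuts whose codec token sequence fits the codec LM."""
    codec_len = len(c.custom["answer_cosyvoice_speech_token"])
    if codec_len > 2200:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. "
            f"Duration: {c.duration}, length: {codec_len}"
        )
        return False
    return True

# Hypothetical usage with a lhotse CutSet:
# train_cuts = train_cuts.filter(remove_long_codec_cuts)
```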