add flash attn support

Yuekai Zhang 2025-04-21 14:54:28 +08:00
parent b305cdacc0
commit 7db40052d6
2 changed files with 69 additions and 11 deletions

@@ -287,19 +287,50 @@ class SPEECH_LLM(nn.Module):
             device=input_ids.device
         )
         audio_labels = audio_codes.clone()
+        total_len = audio_codes.shape[1]
         for i, speech_codec in enumerate(speech_codec_ids):
             text_label_start_index = text_label_start_index_list[i]
             speech_codec = torch.tensor(
                 speech_codec, dtype=torch.int64, device=input_ids.device
             )
-            audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
-            audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
-            audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
-            audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
+            speech_codec_len = len(speech_codec)
+            # Calculate lengths of non-padding content
+            codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
+            # Actual label content length (speech codec tokens + eos token)
+            labels_actual_content_len = speech_codec_len + 1
+            if self.codec_lm_padding_side == "right":
+                # Fill audio_codes (right padding)
+                codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+                audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
+                audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
+                # Fill audio_labels (right padding)
+                labels_start_idx = text_label_start_index + delay_step
+                labels_speech_end_idx = labels_start_idx + speech_codec_len
+                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
+                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+            elif self.codec_lm_padding_side == "left":
+                # Calculate start indices for left padding (shifting content to the right)
+                codes_start_idx = total_len - codes_len
+                labels_start_idx = total_len - labels_actual_content_len  # Start index for the actual label content
+                # Fill audio_codes (left padding)
+                codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
+                audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id  # mask token_id
+                audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
+                # Fill audio_labels (left padding)
+                labels_speech_end_idx = labels_start_idx + speech_codec_len
+                # Note: The beginning part remains pad_token_id
+                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
+                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+            else:
+                raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
         audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
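
For readability, here is a minimal standalone sketch of the two layouts the loop above produces — toy lengths and hypothetical bos/pad ids, not the commit's actual values:

```python
import torch

# Toy sizes and hypothetical special-token ids (not the commit's values).
bos_id, pad_id = 0, 2
total_len = 12
text_label_start_index, delay_step = 3, 1
speech_codec = torch.tensor([7, 8, 9], dtype=torch.int64)
speech_codec_len = len(speech_codec)

# Content length: bos prefix (text span + delay + 1) followed by the codec tokens.
codes_len = text_label_start_index + delay_step + 1 + speech_codec_len  # 8

# Right padding: content starts at index 0, pads fill the tail.
codes_right = torch.full((total_len,), pad_id, dtype=torch.int64)
codes_right[: text_label_start_index + delay_step + 1] = bos_id
codes_right[text_label_start_index + delay_step + 1 : codes_len] = speech_codec

# Left padding: the same content shifted right by (total_len - codes_len).
codes_left = torch.full((total_len,), pad_id, dtype=torch.int64)
start = total_len - codes_len
codes_left[start : start + text_label_start_index + delay_step + 1] = bos_id
codes_left[start + text_label_start_index + delay_step + 1 :] = speech_codec

print(codes_right.tolist())  # [0, 0, 0, 0, 0, 7, 8, 9, 2, 2, 2, 2]
print(codes_left.tolist())   # [2, 2, 2, 2, 0, 0, 0, 0, 0, 7, 8, 9]
```

In both layouts the codec tokens sit `delay_step` positions after the text span ends, so the codec LM sees the corresponding text before it has to emit speech tokens; only the position of the pad run differs.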
@@ -308,7 +339,24 @@ class SPEECH_LLM(nn.Module):
         text_input_embeds = inputs_embeds + text_last_hidden_outputs
         text_input_embeds = self.speech_token_projector(text_input_embeds)
-        audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds
+        T_merged = text_input_embeds.shape[1]
+        T_audio = audio_embeddings.shape[1]
+        if self.codec_lm_padding_side == "right":
+            # Add to the beginning for right padding
+            audio_embeddings[:, :T_merged] += text_input_embeds
+        elif self.codec_lm_padding_side == "left":
+            # Need to add to the shifted position for left padding
+            # Calculate the length of the non-padded sequence for each item
+            seq_lens = audio_attention_mask.sum(dim=1)  # Shape (B)
+            for i in range(audio_embeddings.shape[0]):
+                item_len = seq_lens[i].item()  # Get the non-padded length for item i
+                start_idx_content = T_audio - item_len  # Start index of the content for item i
+                end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
+                # Add the text_input_embeds to the calculated slice
+                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+        else:
+            raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
         speech_outputs = self.codec_lm(
             attention_mask=audio_attention_mask,
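
The left-padding branch has to re-derive where each item's content starts, because `text_input_embeds` is aligned with the first content position rather than with position 0 of the padded tensor. A small sketch with toy shapes, assuming the attention mask is 1 on content and 0 on pads:

```python
import torch

B, T_audio, T_merged, D = 2, 6, 2, 4
audio_embeddings = torch.zeros(B, T_audio, D)
text_input_embeds = torch.ones(B, T_merged, D)

# Left-padded mask: item 0 has 4 content tokens, item 1 has 5.
audio_attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 1],
])

seq_lens = audio_attention_mask.sum(dim=1)  # tensor([4, 5])
for i in range(B):
    start = T_audio - seq_lens[i].item()  # first non-pad position of item i
    audio_embeddings[i, start : start + T_merged] += text_input_embeds[i]

# Item 0 received the text embeddings at positions 2-3, item 1 at 1-2.
print(audio_embeddings[:, :, 0].tolist())
# [[0.0, 0.0, 1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0, 0.0, 0.0]]
```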

@@ -383,7 +383,7 @@ def compute_loss(
     last_questions = [question.split('<USER>: ')[-1].strip() for question in questions_with_history]
     history_contexts = [question.rsplit('<USER>:', 1)[0].strip() for question in questions_with_history]
     # USER: Write a poem about summer. ASSISTANT: The summer heat blazes, all things grow, the sunshine is bright, enjoying the beautiful days of summer. USER: List some news headlines for me. ASSISTANT: The news of today's society never stops. <USER>: Tell me how to cook chicken
-    # <USER>: Appraise the following sentence: He is kind-hearted. The output is: He is a kind-hearted person.
+    # <USER>: Appraise the following sentence: He is kind-hearted. The output is: "He is a kind-hearted person.
     messages = []
     for i, total_round in enumerate(chat_rounds):
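
The two list comprehensions above split each question into the latest user turn and the preceding history. A hypothetical example of the `split`/`rsplit` behaviour they rely on:

```python
# Hypothetical history string in the format shown in the comments above.
q = "USER: Hello ASSISTANT: Hi there. <USER>: Tell me how to cook chicken"

last_question = q.split('<USER>: ')[-1].strip()
history_context = q.rsplit('<USER>:', 1)[0].strip()

print(last_question)    # Tell me how to cook chicken
print(history_context)  # USER: Hello ASSISTANT: Hi there.
```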
@@ -730,11 +730,14 @@ def run(rank, world_size, args):
     if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
+        if params.use_flash_attn:
+            attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
+        else:
+            attn_implementation = "eager"
+            torch_dtype = torch.float16
-        torch_dtype = torch.float16
         # codec_lm = AutoModelForCausalLM.from_pretrained(
         #     params.llm_path_or_name,
         #     attn_implementation=attn_implementation,
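
`flash_attention_2` requires the flash-attn package and fp16/bf16 execution, which is why the dtype is pinned next to the implementation choice. A defensive variant (a sketch, not part of this commit, assuming a transformers version that exposes `is_flash_attn_2_available`) could also gate on availability:

```python
import torch
from transformers.utils import is_flash_attn_2_available

use_flash_attn = True  # stand-in for params.use_flash_attn

# Fall back to eager attention when flash-attn is not installed.
if use_flash_attn and is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.float16  # flash-attn kernels require fp16/bf16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16
```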
@@ -750,7 +753,14 @@ def run(rank, world_size, args):
             intermediate_size=2048,
             max_position_embeddings=4096,
         )
-        codec_lm = Qwen2ForCausalLM(config=config)
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor
+        # Use AutoModelForCausalLM.from_config for more generality
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype
+        )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
         codec_lm.vocab_size = codec_vocab_size
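
Because the codec LM is trained from scratch, `AutoModelForCausalLM.from_config` (random initialization) is the right entry point rather than `from_pretrained`, and it forwards `attn_implementation` and `torch_dtype` to the model constructor. A self-contained sketch; everything except `intermediate_size` and `max_position_embeddings` uses made-up values here, and the vocab arithmetic is hypothetical:

```python
import torch
from transformers import AutoModelForCausalLM, Qwen2Config

config = Qwen2Config(
    hidden_size=1024,          # made-up size for illustration
    num_hidden_layers=12,      # made-up size for illustration
    num_attention_heads=16,    # made-up size for illustration
    intermediate_size=2048,
    max_position_embeddings=4096,
)
# Randomly initialised causal LM with the requested attention backend
# (assumes flash-attn is installed; otherwise pass "eager").
codec_lm = AutoModelForCausalLM.from_config(
    config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
codec_vocab_size = 6561 + 8  # hypothetical: codec tokens plus special tokens
codec_lm.resize_token_embeddings(codec_vocab_size)
```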
@@ -829,7 +839,7 @@ def run(rank, world_size, args):
             return False
         # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
         codec_len = len(c.custom["answer_cosyvoice_speech_token"])
-        if codec_len > 2048:
+        if codec_len > 2200:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
             )
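
Raising the threshold from 2048 to 2200 admits slightly longer codec token sequences into training. The check, restated as a standalone predicate (hypothetical helper; `c` is assumed to be a lhotse cut carrying the custom field shown above):

```python
import logging

def keep_cut(c, max_codec_len: int = 2200) -> bool:
    """Return False for cuts whose codec token sequence is too long."""
    codec_len = len(c.custom["answer_cosyvoice_speech_token"])
    if codec_len > max_codec_len:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. "
            f"Duration: {c.duration}, length: {codec_len}"
        )
        return False
    return True
```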