From 7db40052d6a682b8f0743ca3289da13f98137f92 Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Mon, 21 Apr 2025 14:54:28 +0800
Subject: [PATCH] add flash attn support

---
 .../SPEECH2SPEECH/slam_omni/model.py | 62 ++++++++++++++++---
 .../SPEECH2SPEECH/slam_omni/train.py | 18 ++++--
 2 files changed, 69 insertions(+), 11 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index 3a539c6ab..22f627ecc 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -287,19 +287,50 @@ class SPEECH_LLM(nn.Module):
                 device=input_ids.device
             )
             audio_labels = audio_codes.clone()
+            total_len = audio_codes.shape[1]
             for i, speech_codec in enumerate(speech_codec_ids):
                 text_label_start_index = text_label_start_index_list[i]
                 speech_codec = torch.tensor(
                     speech_codec, dtype=torch.int64, device=input_ids.device
                 )
-                audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
-                audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
-                audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
-                audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
+                speech_codec_len = len(speech_codec)
+
+                # Calculate lengths of non-padding content
+                codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
+                # Actual label content length (speech codec tokens + eos token)
+                labels_actual_content_len = speech_codec_len + 1
+
+                if self.codec_lm_padding_side == "right":
+                    # Fill audio_codes (right padding)
+                    codes_end_idx = text_label_start_index + delay_step + 1 + speech_codec_len
+                    audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
+                    audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
+
+                    # Fill audio_labels (right padding)
+                    labels_start_idx = text_label_start_index + delay_step
+                    labels_speech_end_idx = labels_start_idx + speech_codec_len
+                    audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
+                    audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+
+                elif self.codec_lm_padding_side == "left":
+                    # Calculate start indices for left padding (shifting content to the right)
+                    codes_start_idx = total_len - codes_len
+                    labels_start_idx = total_len - labels_actual_content_len # Start index for the actual label content
+
+                    # Fill audio_codes (left padding)
+                    codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
+                    audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id # mask token_id
+                    audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
+
+                    # Fill audio_labels (left padding)
+                    labels_speech_end_idx = labels_start_idx + speech_codec_len
+                    # Note: The beginning part remains pad_token_id
+                    audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
+                    audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
+                else:
+                    raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
 
-            if self.codec_lm_padding_side == "left":
-                pass
             audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
             audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
@@ -308,7 +339,24 @@ class SPEECH_LLM(nn.Module):
 
             text_input_embeds = inputs_embeds + text_last_hidden_outputs
             text_input_embeds = self.speech_token_projector(text_input_embeds)
-            audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds
+            T_merged = text_input_embeds.shape[1]
+            T_audio = audio_embeddings.shape[1]
+
+            if self.codec_lm_padding_side == "right":
+                # Add to the beginning for right padding
+                audio_embeddings[:, :T_merged] += text_input_embeds
+            elif self.codec_lm_padding_side == "left":
+                # Need to add to the shifted position for left padding
+                # Calculate the length of the non-padded sequence for each item
+                seq_lens = audio_attention_mask.sum(dim=1) # Shape (B)
+                for i in range(audio_embeddings.shape[0]):
+                    item_len = seq_lens[i].item() # Get the non-padded length for item i
+                    start_idx_content = T_audio - item_len # Start index of the content for item i
+                    end_idx_target = start_idx_content + T_merged # End index of the target slice within the content
+                    # Add the text_input_embeds to the calculated slice
+                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+            else:
+                raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
 
             speech_outputs = self.codec_lm(
                 attention_mask=audio_attention_mask,
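For reference, the model.py hunks above lay out the codec stream per padding side: a bos/"mask" prefix of length text_label_start_index + delay_step + 1, then the speech codec tokens, with labels starting one step earlier and terminated by eos; for left padding the same content is shifted right so it ends at total_len. Below is a minimal standalone sketch of that index arithmetic, assuming toy lengths and illustrative bos/eos/pad ids; the helper name build_row is hypothetical and does not exist in the repo.

import torch

def build_row(total_len, text_label_start_index, delay_step, speech_codec,
              bos_id=2, eos_id=3, pad_id=0, padding_side="left"):
    # Toy re-implementation of the per-row fill logic; ids and helper are illustrative.
    codes = torch.full((total_len,), pad_id, dtype=torch.int64)
    labels = torch.full((total_len,), pad_id, dtype=torch.int64)
    n = len(speech_codec)
    prefix = text_label_start_index + delay_step + 1   # bos/mask prefix length
    codes_len = prefix + n                              # non-padding length of codes
    labels_len = n + 1                                  # speech tokens + eos

    if padding_side == "right":
        codes[:prefix] = bos_id
        codes[prefix:prefix + n] = torch.tensor(speech_codec)
        labels[prefix - 1:prefix - 1 + n] = torch.tensor(speech_codec)
        labels[prefix - 1 + n] = eos_id
    else:  # "left": same content, shifted so it ends at total_len
        start = total_len - codes_len
        codes[start:start + prefix] = bos_id
        codes[start + prefix:] = torch.tensor(speech_codec)
        lstart = total_len - labels_len
        labels[lstart:lstart + n] = torch.tensor(speech_codec)
        labels[lstart + n] = eos_id
    return codes, labels

codes, labels = build_row(total_len=12, text_label_start_index=2, delay_step=1,
                          speech_codec=[11, 12, 13], padding_side="left")
print(codes.tolist())   # 5 pads, then 4 bos slots, then 11 12 13 at the end
print(labels.tolist())  # 8 pads, then 11 12 13, eos in the last position

With left padding the offsets codes_start_idx / labels_start_idx in the patch play the role of start / lstart here, which is also why the second hunk has to add the projected text embeddings at T_audio - item_len instead of at position 0.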
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index 0d217cf36..143c10c68 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -383,7 +383,7 @@ def compute_loss(
     last_questions = [question.split(': ')[-1].strip() for question in questions_with_history]
     history_contexts = [question.rsplit(':', 1)[0].strip() for question in questions_with_history]
     # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。: 告诉我如何烹饪鸡肉
-    # : 对以下句子进行鉴赏:他心地善良。输出结果为“他是一个有善心的人。
+    # : 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
 
     messages = []
     for i, total_round in enumerate(chat_rounds):
@@ -730,11 +730,14 @@ def run(rank, world_size, args):
 
     if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
         if params.use_flash_attn:
             attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16 # Or torch.bfloat16 if needed/supported
         else:
             attn_implementation = "eager"
-        torch_dtype = torch.float16
+            torch_dtype = torch.float16
+
         # codec_lm = AutoModelForCausalLM.from_pretrained(
         #     params.llm_path_or_name,
         #     attn_implementation=attn_implementation,
@@ -750,7 +753,14 @@ def run(rank, world_size, args):
             intermediate_size=2048,
             max_position_embeddings=4096,
         )
-        codec_lm = Qwen2ForCausalLM(config=config)
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor
+        # Use AutoModelForCausalLM.from_config for more generality
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype
+        )
         # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
         codec_lm.vocab_size = codec_vocab_size
@@ -829,7 +839,7 @@ def run(rank, world_size, args):
             return False
         # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
         codec_len = len(c.custom["answer_cosyvoice_speech_token"])
-        if codec_len > 2048:
+        if codec_len > 2200:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
             )
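For reference, the train.py change builds the codec LM from a config via AutoModelForCausalLM.from_config instead of Qwen2ForCausalLM(config=config), so that attn_implementation and torch_dtype reach the constructor. Below is a minimal sketch of that construction path; the Qwen2Config field values are placeholders (only intermediate_size and max_position_embeddings are visible in the patch context), and flash_attention_2 assumes a supported GPU with the flash-attn package installed, otherwise keep "eager".

import torch
from transformers import AutoModelForCausalLM, Qwen2Config

use_flash_attn = True  # stands in for params.use_flash_attn
attn_implementation = "flash_attention_2" if use_flash_attn else "eager"
torch_dtype = torch.float16  # or torch.bfloat16 if needed/supported

# Placeholder config; the real sizes in the script are not all shown in this hunk.
config = Qwen2Config(
    vocab_size=4096,
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=16,
    num_key_value_heads=16,
    intermediate_size=2048,
    max_position_embeddings=4096,
)

codec_lm = AutoModelForCausalLM.from_config(
    config=config,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype,
)
# The training script then resizes the embeddings to the actual codec vocabulary:
# codec_lm.resize_token_embeddings(codec_vocab_size)

Moving torch_dtype next to the attention choice matters because flash_attention_2 in transformers only runs with fp16/bf16 weights, so the dtype has to be decided together with the attention implementation.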