diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index f7e436806..3a539c6ab 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -59,6 +59,7 @@ class SPEECH_LLM(nn.Module):
         llm: nn.Module,
         encoder_projector: nn.Module,
         codec_lm: nn.Module = None,
+        use_flash_attention: bool = False,
     ):
         super().__init__()
         self.encoder = encoder
@@ -73,6 +74,7 @@ class SPEECH_LLM(nn.Module):
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
             )
             self.loss_fct = torch.nn.CrossEntropyLoss()
+            self.codec_lm_padding_side = "left" if use_flash_attention else "right"

     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
@@ -291,12 +293,13 @@ class SPEECH_LLM(nn.Module):
             speech_codec = torch.tensor(
                 speech_codec, dtype=torch.int64, device=input_ids.device
             )
-            # print(inputs_embeds[i, text_label_start_index], "2333 test")
             audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
             audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
             audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
             audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
-
+
+            if self.codec_lm_padding_side == "left":
+                pass
         audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index 9823492bf..0d217cf36 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -783,6 +783,7 @@ def run(rank, world_size, args):
         llm,
         encoder_projector,
         codec_lm,
+        params.use_flash_attn,
    )

    if params.pretrained_model_path:
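Note: the "left" branch in model.py is still a stub (`pass`) in this patch. Below is a minimal sketch, not part of the PR, of one way the codec tensors could be re-laid-out for left padding, assuming `audio_codes` / `audio_labels` are built right-padded with `pad_token_id` as in the hunk above and that the non-pad content of each row is a contiguous prefix; the helper name `right_to_left_pad` is hypothetical.

```python
import torch


def right_to_left_pad(codes: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Illustrative sketch only: move each row's trailing pad tokens to the
    front, turning right-padded rows into the left-padded layout that
    flash-attention style kernels typically expect.

    Assumes non-pad content is a contiguous prefix of each row and that no
    real token equals pad_token_id (labels padded with a different ignore
    value would need the same treatment with that value).
    """
    out = torch.full_like(codes, pad_token_id)
    lengths = codes.ne(pad_token_id).sum(dim=1)  # non-pad count per row
    for i, n in enumerate(lengths.tolist()):
        # shift the row's content to the right edge, leaving pads at the front
        out[i, codes.shape[1] - n :] = codes[i, :n]
    return out
```

After such a shift, the attention mask would be recomputed from the shifted tensor (e.g. `shifted.ne(pad_token_id)`), mirroring the `audio_attention_mask` line in the diff.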