fix padding side

This commit is contained in:
root 2025-04-21 06:23:10 +00:00
parent bdb60f6ddc
commit b305cdacc0
2 changed files with 6 additions and 2 deletions

View File

@ -59,6 +59,7 @@ class SPEECH_LLM(nn.Module):
llm: nn.Module, llm: nn.Module,
encoder_projector: nn.Module, encoder_projector: nn.Module,
codec_lm: nn.Module = None, codec_lm: nn.Module = None,
use_flash_attention: bool = False,
): ):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
@ -73,6 +74,7 @@ class SPEECH_LLM(nn.Module):
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
) )
self.loss_fct = torch.nn.CrossEntropyLoss() self.loss_fct = torch.nn.CrossEntropyLoss()
self.codec_lm_padding_side = "left" if use_flash_attention else "right"
def _merge_input_ids_with_speech_features( def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
@ -291,12 +293,13 @@ class SPEECH_LLM(nn.Module):
speech_codec = torch.tensor( speech_codec = torch.tensor(
speech_codec, dtype=torch.int64, device=input_ids.device speech_codec, dtype=torch.int64, device=input_ids.device
) )
# print(inputs_embeds[i, text_label_start_index], "2333 test")
audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id # mask token_id
audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
if self.codec_lm_padding_side == "left":
pass
audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id) audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes) audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

View File

@ -783,6 +783,7 @@ def run(rank, world_size, args):
llm, llm,
encoder_projector, encoder_projector,
codec_lm, codec_lm,
params.use_flash_attn,
) )
if params.pretrained_model_path: if params.pretrained_model_path: