change padding side name

This commit is contained in:
Yuekai Zhang 2025-04-21 17:10:25 +08:00
parent 7db40052d6
commit 09d81b44a7
2 changed files with 3 additions and 3 deletions

View File

@ -59,7 +59,7 @@ class SPEECH_LLM(nn.Module):
llm: nn.Module, llm: nn.Module,
encoder_projector: nn.Module, encoder_projector: nn.Module,
codec_lm: nn.Module = None, codec_lm: nn.Module = None,
use_flash_attention: bool = False, codec_lm_padding_side: str = "left",
): ):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
@ -74,7 +74,7 @@ class SPEECH_LLM(nn.Module):
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
) )
self.loss_fct = torch.nn.CrossEntropyLoss() self.loss_fct = torch.nn.CrossEntropyLoss()
self.codec_lm_padding_side = "left" if use_flash_attention else "right" self.codec_lm_padding_side = codec_lm_padding_side
def _merge_input_ids_with_speech_features( def _merge_input_ids_with_speech_features(
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None

View File

@ -793,7 +793,7 @@ def run(rank, world_size, args):
llm, llm,
encoder_projector, encoder_projector,
codec_lm, codec_lm,
params.use_flash_attn, codec_lm_padding_side= "left" if params.use_flash_attn else "right",
) )
if params.pretrained_model_path: if params.pretrained_model_path: