mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
change padding side name
This commit is contained in:
parent
7db40052d6
commit
09d81b44a7
@ -59,7 +59,7 @@ class SPEECH_LLM(nn.Module):
|
||||
llm: nn.Module,
|
||||
encoder_projector: nn.Module,
|
||||
codec_lm: nn.Module = None,
|
||||
use_flash_attention: bool = False,
|
||||
codec_lm_padding_side: str = "left",
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
@ -74,7 +74,7 @@ class SPEECH_LLM(nn.Module):
|
||||
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
|
||||
)
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss()
|
||||
self.codec_lm_padding_side = "left" if use_flash_attention else "right"
|
||||
self.codec_lm_padding_side = codec_lm_padding_side
|
||||
|
||||
def _merge_input_ids_with_speech_features(
|
||||
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
|
||||
|
@ -793,7 +793,7 @@ def run(rank, world_size, args):
|
||||
llm,
|
||||
encoder_projector,
|
||||
codec_lm,
|
||||
params.use_flash_attn,
|
||||
codec_lm_padding_side= "left" if params.use_flash_attn else "right",
|
||||
)
|
||||
|
||||
if params.pretrained_model_path:
|
||||
|
Loading…
x
Reference in New Issue
Block a user