mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
fix padding side
parent bdb60f6ddc
commit b305cdacc0
@@ -59,6 +59,7 @@ class SPEECH_LLM(nn.Module):
         llm: nn.Module,
         encoder_projector: nn.Module,
         codec_lm: nn.Module = None,
+        use_flash_attention: bool = False,
     ):
         super().__init__()
         self.encoder = encoder
@@ -73,6 +74,7 @@ class SPEECH_LLM(nn.Module):
             self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
         )
         self.loss_fct = torch.nn.CrossEntropyLoss()
+        self.codec_lm_padding_side = "left" if use_flash_attention else "right"
 
     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
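For context: this hunk pairs flash attention with left padding for the codec LM. The sketch below is illustrative only (not from the repo; `pad_id` and `seqs` are made-up names) and just shows how the two layouts differ for a batch of codec token sequences — with left padding, every row's last real token sits in the final column, which suits batched causal decoding.

    import torch

    pad_id = 0
    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
    max_len = max(len(s) for s in seqs)

    def pad_to(seq, side):
        fill = torch.full((max_len - len(seq),), pad_id, dtype=seq.dtype)
        # "left": pads go in front, real tokens end-aligned; "right": the reverse
        return torch.cat([fill, seq]) if side == "left" else torch.cat([seq, fill])

    left = torch.stack([pad_to(s, "left") for s in seqs])    # [[5, 6, 7], [0, 8, 9]]
    right = torch.stack([pad_to(s, "right") for s in seqs])  # [[5, 6, 7], [8, 9, 0]]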
@@ -291,12 +293,13 @@ class SPEECH_LLM(nn.Module):
                 speech_codec = torch.tensor(
                     speech_codec, dtype=torch.int64, device=input_ids.device
                 )
-                # print(inputs_embeds[i, text_label_start_index], "2333 test")
                 audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
                 audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
                 audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
                 audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
 
+            if self.codec_lm_padding_side == "left":
+                pass
             audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
             audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
 
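Note that the new left-padding branch is still a `pass` stub in this commit; the rows built above remain right-padded. A purely hypothetical sketch of what that branch could eventually do — repack each right-padded row into a left-padded one (`pad_id` stands in for `self.codec_lm.config.pad_token_id`; the helper name is invented, nothing here comes from the repo):

    import torch

    def right_to_left_pad(codes: torch.Tensor, pad_id: int) -> torch.Tensor:
        # Move each row's trailing pad tokens to the front, keeping the
        # real tokens in their original order, end-aligned.
        out = torch.full_like(codes, pad_id)
        for i, row in enumerate(codes):
            content = row[row != pad_id]
            out[i, codes.size(1) - len(content):] = content
        return out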
@@ -783,6 +783,7 @@ def run(rank, world_size, args):
         llm,
         encoder_projector,
         codec_lm,
+        params.use_flash_attn,
     )
 
     if params.pretrained_model_path:
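Net effect of this last hunk: the training script forwards `params.use_flash_attn` positionally into the `SPEECH_LLM` constructor, where it binds to the `use_flash_attention` parameter added in the first hunk and, through it, selects the codec LM's padding side ("left" when flash attention is enabled, "right" otherwise).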