diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index a1c31a252..e0a2fa507 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -51,9 +51,10 @@ fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "stage 3: "
+  exp_dir=./slam_omni/exp_speech2speech_rerun
   python3 ./slam_omni/decode.py \
     --max-duration 1 \
-    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
+    --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --epoch 997 --avg 1 \
@@ -87,21 +88,23 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "stage 5: "
-  ngpu=2
-  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
-torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
-  --max-duration 40 \
-  --enable-musan False \
-  --exp-dir $exp_dir \
-  --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-  --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-  --manifest-dir data/fbank \
-  --deepspeed \
-  --deepspeed_config ./slam_omni/ds_config_zero1.json \
-  --use-flash-attn True \
-  --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
-  --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
-  # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-  # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
+  ngpu=8
+  exp_dir=./slam_omni/exp_speech2speech_rerun
+  # exp_dir_new=./slam_omni/exp_s2s
+  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
+    --max-duration 50 \
+    --enable-musan False \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --manifest-dir data/fbank \
+    --deepspeed \
+    --deepspeed_config ./slam_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
+    --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+  # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
+  # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
 fi
\ No newline at end of file
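Note on the stage 5 changes: the rerun bumps `ngpu` to 8, raises `--max-duration` to 50, and resumes from checkpoint 15000 together with its sampler state, so the dataloader skips the cuts already consumed in epoch 1 instead of replaying them. A minimal sketch of what the `--sampler-state-dict-path` flag is expected to do inside `train.py`, assuming Lhotse's sampler `state_dict()`/`load_state_dict()` API; the manifest filename here is illustrative, not taken from the patch:

```python
# Sketch: resuming a Lhotse sampler mid-epoch from a saved state dict.
import torch
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # hypothetical manifest path
sampler = DynamicBucketingSampler(cuts, max_duration=50, shuffle=True)

state = torch.load(
    "slam_omni/exp_speech2speech_rerun/epoch-1-checkpoint-15000-sampler.pt"
)
sampler.load_state_dict(state)  # fast-forwards past the cuts seen before step 15000
```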
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
index f23cd5f5d..54a8983df 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
@@ -579,7 +579,7 @@ def main():
         #     attn_implementation=attn_implementation,
         #     torch_dtype=torch_dtype,
         # )
-        codec_vocab_size = 8192
+        codec_vocab_size = 4096 + 4
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
             hidden_size=1024,
@@ -603,24 +603,25 @@
         codec_lm.config.pad_token_id = codec_vocab_size - 1
         codec_lm.config.eos_token_id = codec_vocab_size - 2
         codec_lm.config.bos_token_id = codec_vocab_size - 3
-        if params.use_lora:
-            lora_config = LoraConfig(
-                r=64,
-                lora_alpha=16,
-                target_modules=[
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "gate_proj",
-                    "down_proj",
-                ],
-                lora_dropout=0.05,
-                task_type="CAUSAL_LM",
-            )
-            codec_lm = get_peft_model(codec_lm, lora_config)
-            codec_lm.print_trainable_parameters()
+        codec_lm.config.mask_token_id = codec_vocab_size - 4
+        # if params.use_lora:
+        #     lora_config = LoraConfig(
+        #         r=64,
+        #         lora_alpha=16,
+        #         target_modules=[
+        #             "q_proj",
+        #             "k_proj",
+        #             "v_proj",
+        #             "o_proj",
+        #             "up_proj",
+        #             "gate_proj",
+        #             "down_proj",
+        #         ],
+        #         lora_dropout=0.05,
+        #         task_type="CAUSAL_LM",
+        #     )
+        #     codec_lm = get_peft_model(codec_lm, lora_config)
+        #     codec_lm.print_trainable_parameters()
     else:
         codec_lm = None
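The decode-side codec LM now matches the training config below: the vocab shrinks from 8192 to the codec's 4096 real entries plus four control ids, and a mask token is added for the text-conditioned prefix. For reference, the id layout that follows directly from the assignments in the hunk above:

```python
# Id layout implied by the config assignments above: real codec tokens occupy
# 0..4095; the four control ids sit at the top of the enlarged vocab.
codec_vocab_size = 4096 + 4

pad_token_id = codec_vocab_size - 1   # 4099
eos_token_id = codec_vocab_size - 2   # 4098
bos_token_id = codec_vocab_size - 3   # 4097
mask_token_id = codec_vocab_size - 4  # 4096, new: fills the text-conditioned prefix

# None of the control ids collide with a real codec token.
assert {pad_token_id, eos_token_id, bos_token_id, mask_token_id}.isdisjoint(range(4096))
```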
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index 1c110470e..c5f31226d 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -4,7 +4,7 @@ from transformers.trainer_pt_utils import LabelSmoother
 from typing import List, Tuple  # Added for type hints
 from torchmetrics.classification import MulticlassAccuracy
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
-
+import logging

 class EncoderProjector(nn.Module):
     """
@@ -69,7 +69,7 @@ class SPEECH_LLM(nn.Module):
         self.codec_lm = codec_lm
         if self.codec_lm:
             self.speech_token_projector = nn.Linear(
-                self.llm.config.hidden_size, self.codec_lm.config.hidden_size
+                self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size
             )
             self.codec_lm_head = nn.Linear(
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
@@ -274,110 +274,92 @@ class SPEECH_LLM(nn.Module):
         ) = self._merge_input_ids_with_speech_features(
             speech_features, inputs_embeds, input_ids, attention_mask, labels
         )
-
-        # get the label start_index in inputs_embeds from labels
-        text_label_start_index_list = []
+        input_seq_len = attention_mask.sum(dim=1)  # shape, B
+        text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], []
         for i in range(labels.shape[0]):
-            text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
-            text_label_start_index_list.append(text_label_start_index)
-            # TODO1: check text_label_start_index position
-            print(i, input_ids[i], input_ids[i].shape, labels[i], labels[i].shape, text_label_start_index, labels[i][text_label_start_index])
+            input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0]
+            input_embeds_start_index = input_embeds_valid_index[0]
+            text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0]
+            text_labels_start_index = text_labels_valid_index[0]
+
+            assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}"
+            assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}"
+            input_question_len = text_labels_start_index - input_embeds_start_index
+            assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i]
+            text_label_start_index_list.append(text_labels_start_index)
+            text_input_start_index_list.append(input_embeds_start_index)
+            input_question_len_list.append(input_question_len)

         model_outputs = self.llm(
             inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
         )
         text_loss = model_outputs.loss
-
+        delay_step = 1
         # prepare codec lm inputs
-        audio_codes_lens = torch.tensor(
-            [len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
-        )
-        # print(audio_codes_lens, "audio_codes_lens")
+        audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)]
         max_len_speech_codec = max(audio_codes_lens)
-        delay_step = 2
-        audio_codes = torch.full(
-            (inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
-            self.codec_lm.config.pad_token_id,
+
+        if self.codec_lm_padding_side == "right":
+            audio_codes = [
+                [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
+                for i, x in enumerate(speech_codec_ids)
+            ]
+            audio_labels = [
+                [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i])
+                for i, x in enumerate(speech_codec_ids)
+            ]
+        elif self.codec_lm_padding_side == "left":
+            audio_codes = [
+                [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x
+                for i, x in enumerate(speech_codec_ids)
+            ]
+            audio_labels = [
+                [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id]
+                for i, x in enumerate(speech_codec_ids)
+            ]
+        audio_codes = torch.tensor(
+            audio_codes,
+            dtype=torch.int64,
+            device=input_ids.device
+        )
+        audio_labels = torch.tensor(
+            audio_labels, dtype=torch.int64, device=input_ids.device
         )
-        audio_labels = audio_codes.clone()
-        total_len = audio_codes.shape[1]
-        for i, speech_codec in enumerate(speech_codec_ids):
-            text_label_start_index = text_label_start_index_list[i]
-            speech_codec = torch.tensor(
-                speech_codec, dtype=torch.int64, device=input_ids.device
-            )
-            speech_codec_len = len(speech_codec)
-
-            # Calculate lengths of non-padding content
-            codes_len = text_label_start_index + delay_step + 1 + speech_codec_len
-            # Actual label content length (speech codec tokens + eos token)
-            labels_actual_content_len = speech_codec_len + 1
-
-            if self.codec_lm_padding_side == "right":
-                # Fill audio_codes (right padding)
-                codes_end_idx = codes_len
-                audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # mask token_id
-                audio_codes[i, text_label_start_index + delay_step + 1 : codes_end_idx] = speech_codec
-
-                # Fill audio_labels (right padding)
-                labels_start_idx = text_label_start_index + delay_step
-                labels_speech_end_idx = labels_start_idx + speech_codec_len
-                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
-                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
-
-            elif self.codec_lm_padding_side == "left":
-                # Calculate start indices for left padding (shifting content to the right)
-                codes_start_idx = total_len - codes_len
-                labels_start_idx = total_len - labels_actual_content_len  # Start index for the actual label content
-
-                # Fill audio_codes (left padding)
-                codes_speech_start_idx = codes_start_idx + text_label_start_index + delay_step + 1
-                audio_codes[i, codes_start_idx : codes_speech_start_idx] = self.codec_lm.config.bos_token_id  # mask token_id
-                audio_codes[i, codes_speech_start_idx : total_len] = speech_codec
-
-                # Fill audio_labels (left padding)
-                labels_speech_end_idx = labels_start_idx + speech_codec_len
-                # Note: The beginning part remains pad_token_id
-                audio_labels[i, labels_start_idx : labels_speech_end_idx] = speech_codec
-                audio_labels[i, labels_speech_end_idx] = self.codec_lm.config.eos_token_id
-            else:
-                raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
-
-        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)  # TODO: do we need to change bos tokens to pad token or mask token?
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

-        # input_ids: seq_len T1, audio_codec seq_len T2
-        text_last_hidden_outputs = model_outputs.hidden_states[-1]
-        text_input_embeds = inputs_embeds + text_last_hidden_outputs  # TODO: this computation is wrong, output tokens' embedding?
-        text_input_embeds = self.speech_token_projector(text_input_embeds)
+        text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        for i in range(len(text_label_start_index_list)):
+            text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1]
+            text_last_hidden_lists.append(text_last_hidden)
+            text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]]  # exclude bos
+            text_embeds_list.append(text_embed)

-        T_merged = text_input_embeds.shape[1]
-        T_audio = audio_embeddings.shape[1]
-
-        if self.codec_lm_padding_side == "right":
-            # Add to the beginning for right padding
-            audio_embeddings[:, :T_merged] += text_input_embeds
-        elif self.codec_lm_padding_side == "left":
-            # Need to add to the shifted position for left padding
-            # Calculate the length of the non-padded sequence for each item
-            seq_lens = audio_attention_mask.sum(dim=1)  # Shape (B)
-            print(seq_lens[0], audio_codes[0], "======================")
-            for i in range(audio_embeddings.shape[0]):
-                item_len = seq_lens[i].item()  # Get the non-padded length for item i
-                start_idx_content = T_audio - item_len  # Start index of the content for item i
-                end_idx_target = start_idx_content + T_merged  # End index of the target slice within the content
-                # Add the text_input_embeds to the calculated slice
-                if end_idx_target > T_audio:
-                    # If the text input is longer than the audio input, we need to pad the audio input
-                    cut_off_len = T_audio - start_idx_content
-                    audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
-                else:
-                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
-        else:
-            raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")
+            text_input_embeds = torch.cat(
+                [
+                    text_last_hidden,
+                    text_embed,
+                ],
+                dim=-1,
+            )  # shape, T, D1 + D2
+            text_input_embeds = self.speech_token_projector(text_input_embeds)  # shape, T, D_codec
+            text_input_embeds_list.append(text_input_embeds)
+
+        for i in range(audio_embeddings.shape[0]):
+            text_input_embeds = text_input_embeds_list[i]
+            if self.codec_lm_padding_side == "right":
+                audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds
+            elif self.codec_lm_padding_side == "left":
+                start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0]
+                start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0]
+                assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
+                if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
+                    text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx]
+                    logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}")
+                audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds

         speech_outputs = self.codec_lm(
             attention_mask=audio_attention_mask,
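To make the list comprehensions in the hunk above concrete, here is a toy right-padded batch row (the lengths are hypothetical; `MASK`/`BOS`/`EOS`/`PAD` stand for the four config token ids):

```python
# Toy illustration of the right-padding branch: question length 2,
# delay_step 1, codec ids [7, 8, 9], batch max length 9 (forces padding).
MASK, BOS, EOS, PAD = 4096, 4097, 4098, 4099
question_len, delay_step = 2, 1
x = [7, 8, 9]
max_len = 9  # max(audio_codes_lens) over the batch

codes = [MASK] * (question_len + delay_step) + [BOS] + x
labels = [PAD] * (question_len + delay_step) + x + [EOS]
codes += [PAD] * (max_len - len(codes))
labels += [PAD] * (max_len - len(labels))

assert codes == [MASK, MASK, MASK, BOS, 7, 8, 9, PAD, PAD]
assert labels == [PAD, PAD, PAD, 7, 8, 9, EOS, PAD, PAD]
```

The labels lead the codes by one position, so the model is trained to emit the first codec token at the `BOS` slot, with the preceding `MASK` positions reserved for the projected text conditioning. The remaining model.py hunks continue below.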
@@ -545,26 +527,56 @@ class SPEECH_LLM(nn.Module):
             output_hidden_states=True,
             **final_llm_kwargs
         )
-
+        delay_step = 1
         generated_text_ids = text_outputs.sequences  # [B, S_full]
-        thinker_token_embeds = [
+        eos_token_id = self.llm.config.eos_token_id
+        eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device))  # 1,D
+        assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
+        thinker_token_embeds_org = [
             token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
         ]
+        # shift one for thinker token_embeds, drop the first embeds, and add the eos token
+        first_thinker_token_embed = torch.cat(
+            [
+                thinker_token_embeds_org[0][:, 1:],
+                thinker_token_embeds_org[1],
+            ],
+            dim=1,
+        )
+
+        thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding]
         thinker_hidden_states = [
             token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
         ]
-        thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
-        thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
+        # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+        thinker_reply_part = [
+            torch.cat(
+                [
+                    thinker_hidden_state,
+                    thinker_token_embed,
+                ],
+                dim=-1,
+            )
+            for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:])
+        ]
+        thinker_reply_part = torch.cat(thinker_reply_part, dim=1)
+        # thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]
+        thinker_prompt_part = torch.cat(
+            [
+                thinker_hidden_states[0],
+                thinker_token_embeds[0],
+            ],
+            dim=-1,
+        )

         thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)  # [B, S_full, D_codec]
         thinker_reply_part = self.speech_token_projector(thinker_reply_part)  # [B, S_full, D_codec]
-
-        delay_step = 2
+
         thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
         talker_input_ids = torch.full(
-            (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
+            (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device
         )
+        talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id
         talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_full, D_codec]
         thinker_input_embeds = torch.cat(
             [
@@ -614,7 +626,7 @@ class SPEECH_LLM(nn.Module):
                 # Get logits for the *last* token generated in this step
                 next_token_logits = self.codec_lm_head(last_token_hidden_state)  # Use -1 index
                 # suppress tokens between 4096:len(vocab)-3
-                next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where we should supress tokens?
+                # next_token_logits[:, 4096:-3] = -float("Inf")  # TODO: where we should supress tokens?
                 next_token_ids = topk_sampling(
                     next_token_logits,
                 )
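The projector change in `__init__` (input width doubled) pairs with the concatenations above: instead of summing the LLM's last hidden states and token embeddings, both the forward and inference paths now concatenate them on the feature axis before projecting into the codec LM's space. A standalone sketch of that shape flow, assuming Qwen2.5-0.5B's hidden size of 896 and the 1024-dim codec LM from the configs in this patch:

```python
# Sketch of the doubled projector input: hidden state and input embedding are
# concatenated on the feature axis, then mapped 2*D_llm -> D_codec.
import torch
import torch.nn as nn

D_llm, D_codec, T = 896, 1024, 5  # 896 = Qwen2.5-0.5B hidden size (assumed)
speech_token_projector = nn.Linear(D_llm + D_llm, D_codec)

hidden = torch.randn(1, T, D_llm)  # last-layer hidden states from the LLM
embeds = torch.randn(1, T, D_llm)  # the matching input token embeddings
fused = torch.cat([hidden, embeds], dim=-1)   # [1, T, 2 * D_llm]
conditioned = speech_token_projector(fused)   # [1, T, D_codec]
assert conditioned.shape == (1, T, D_codec)
```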
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index add33b52a..f5356dc43 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -745,7 +745,7 @@ def run(rank, world_size, args):
         #     attn_implementation=attn_implementation,
         #     torch_dtype=torch_dtype,
         # )
-        codec_vocab_size = 8192
+        codec_vocab_size = 4096 + 4  # TODO: modify above vocab size or supress_tokens when decoding
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
@@ -770,24 +770,25 @@ def run(rank, world_size, args):
         codec_lm.config.pad_token_id = codec_vocab_size - 1
         codec_lm.config.eos_token_id = codec_vocab_size - 2
         codec_lm.config.bos_token_id = codec_vocab_size - 3
-        if params.use_lora:
-            lora_config = LoraConfig(
-                r=64,
-                lora_alpha=16,
-                target_modules=[
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "gate_proj",
-                    "down_proj",
-                ],
-                lora_dropout=0.05,
-                task_type="CAUSAL_LM",
-            )
-            codec_lm = get_peft_model(codec_lm, lora_config)
-            codec_lm.print_trainable_parameters()
+        codec_lm.config.mask_token_id = codec_vocab_size - 4
+        # if params.use_lora:
+        #     lora_config = LoraConfig(
+        #         r=64,
+        #         lora_alpha=16,
+        #         target_modules=[
+        #             "q_proj",
+        #             "k_proj",
+        #             "v_proj",
+        #             "o_proj",
+        #             "up_proj",
+        #             "gate_proj",
+        #             "down_proj",
+        #         ],
+        #         lora_dropout=0.05,
+        #         task_type="CAUSAL_LM",
+        #     )
+        #     codec_lm = get_peft_model(codec_lm, lora_config)
+        #     codec_lm.print_trainable_parameters()
     else:
         codec_lm = None
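A small arithmetic note on the commented-out suppression line in `model.py` and the TODO above (my reading of the change, not something stated in the patch): with the old 8192-entry vocab, `next_token_logits[:, 4096:-3]` masked the 4093 ids that held no real codec token; after shrinking to `4096 + 4`, the only id left in that slice is the new mask token, so the broad suppression loses most of its purpose:

```python
# Effect of the vocab shrink on the suppression slice logits[:, 4096:-3].
old_vocab, new_vocab = 8192, 4096 + 4
old_slice = range(4096, old_vocab - 3)  # ids 4096..8188
new_slice = range(4096, new_vocab - 3)  # ids 4096..4096

assert len(old_slice) == 4093           # thousands of unused ids to mask out
assert list(new_slice) == [4096]        # only the new mask token id remains
```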