diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
index 66ccd9974..7de6f7b5d 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
@@ -346,10 +346,19 @@ def decode_one_batch(
         messages.append(message)

     input_ids, attention_mask = preprocess(messages, tokenizer)
-
-    generated_ids = model.decode(
-        feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
-    )
+    if params.enable_speech_output:
+        generated_ids, generated_speech_output = model.decode_with_speech_output(
+            feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        )
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        for cut_id, speech_output in zip(cut_ids, generated_speech_output):
+            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+            # torchaudio.save(save_path, speech_output.cpu(), 16000)
+            print(f"speech_output: {speech_output}, cut_id: {cut_id}")
+    else:
+        generated_ids = model.decode(
+            feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
+        )
     hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

     return {"beam-search": hyps}
@@ -586,10 +595,71 @@ def main():
         speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
     )

+    if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
+        if params.use_flash_attn:
+            attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
+        else:
+            attn_implementation = "eager"
+            torch_dtype = torch.float16
+
+        # codec_lm = AutoModelForCausalLM.from_pretrained(
+        #     params.llm_path_or_name,
+        #     attn_implementation=attn_implementation,
+        #     torch_dtype=torch_dtype,
+        # )
+        codec_vocab_size = 8192
+        config = Qwen2Config(
+            vocab_size=codec_vocab_size,
+            hidden_size=1024,
+            num_hidden_layers=12,
+            num_attention_heads=16,
+            num_key_value_heads=16,
+            intermediate_size=2048,
+            max_position_embeddings=4096,
+        )
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor.
+        # Use AutoModelForCausalLM.from_config for more generality.
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype,
+        )
+        # cosyvoice2_token_size = 6561
+        codec_lm.resize_token_embeddings(codec_vocab_size)
+        codec_lm.vocab_size = codec_vocab_size
+        codec_lm.config.pad_token_id = codec_vocab_size - 1
+        codec_lm.config.eos_token_id = codec_vocab_size - 2
+        codec_lm.config.bos_token_id = codec_vocab_size - 3
+        if params.use_lora:
+            lora_config = LoraConfig(
+                r=64,
+                lora_alpha=16,
+                target_modules=[
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "o_proj",
+                    "up_proj",
+                    "gate_proj",
+                    "down_proj",
+                ],
+                lora_dropout=0.05,
+                task_type="CAUSAL_LM",
+            )
+            codec_lm = get_peft_model(codec_lm, lora_config)
+            codec_lm.print_trainable_parameters()
+    else:
+        codec_lm = None
+
     model = SPEECH_LLM(
         speech_encoder,
         llm,
         encoder_projector,
+        codec_lm,
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
     )

     if params.avg > 1:
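The enable_speech_output branch above only prints the generated codec tokens for now (the torchaudio save is commented out, and the returned tokens are codec ids, not waveforms). A minimal sketch of how the per-cut token lists could be persisted for later vocoding, assuming generated_speech_output is the List[List[int]] returned by decode_with_speech_output and params.exp_dir is a pathlib.Path; the helper name and file layout are illustrative only, not part of this patch:

    import json
    from pathlib import Path
    from typing import List

    import torch

    def save_codec_tokens(
        exp_dir: Path, cut_ids: List[str], generated_speech_output: List[List[int]]
    ) -> None:
        """Write one .pt tensor and one .json sidecar of codec token ids per cut."""
        out_dir = exp_dir / "speech_output"
        out_dir.mkdir(parents=True, exist_ok=True)
        for cut_id, tokens in zip(cut_ids, generated_speech_output):
            # `tokens` is a plain Python list of int codec ids for one utterance.
            torch.save(torch.tensor(tokens, dtype=torch.long), out_dir / f"{cut_id}.pt")
            with open(out_dir / f"{cut_id}.json", "w") as f:
                json.dump({"cut_id": cut_id, "codec_tokens": tokens}, f)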
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index fb3921ba3..8ec707583 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother
+from typing import List, Tuple  # Added for type hints

 IGNORE_TOKEN_ID = LabelSmoother.ignore_index

@@ -444,6 +445,168 @@ class SPEECH_LLM(nn.Module):
         #     )
         return generated_ids

+    def decode_with_speech_output(
+        self,
+        fbank: torch.Tensor = None,
+        input_ids: torch.LongTensor = None,  # Prompt input_ids
+        attention_mask: torch.Tensor = None,  # Prompt attention_mask
+        max_text_new_tokens: int = 1024,
+        max_speech_new_tokens: int = 1024,  # Max length for speech tokens
+        llm_kwargs: dict = None,  # Kwargs for text LLM generate
+        codec_lm_kwargs: dict = None,  # Kwargs for codec LM (e.g., sampling temperature) - NOT IMPLEMENTED YET
+    ) -> Tuple[torch.LongTensor, List[List[int]]]:
+        """
+        Generate text and the corresponding speech codec tokens.
+
+        Args:
+            fbank: Input audio features.
+            input_ids: Input token IDs for the text prompt.
+            attention_mask: Attention mask for the text prompt.
+            max_text_new_tokens: Max new tokens for text generation.
+            max_speech_new_tokens: Max new tokens for speech generation.
+            llm_kwargs: Additional arguments for self.llm.generate.
+            codec_lm_kwargs: Additional arguments for the codec LM generation.
+
+        Returns:
+            Tuple[torch.LongTensor, List[List[int]]]:
+                - generated_text_ids: Tensor of generated text token IDs.
+                - generated_speech_tokens: List of lists, where each inner list contains
+                  the generated speech codec tokens for a batch item.
+        """
+        if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
+            raise ValueError(
+                "codec_lm and associated layers must be initialized to generate speech output."
+            )
+
+        device = next(self.parameters()).device  # Use the model's device
+        batch_size = fbank.shape[0]
+
+        # --- 1. Prepare Prompt Embeddings ---
+        encoder_outs = self.encoder(fbank)
+        speech_features = self.encoder_projector(encoder_outs)
+        speech_features = speech_features.to(self.llm.dtype)  # Ensure matching dtype
+
+        prompt_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        # Merge speech features with prompt embeddings
+        (
+            merged_prompt_inputs_embeds,
+            merged_prompt_attention_mask,
+            _,
+            _,
+        ) = self._merge_input_ids_with_speech_features(
+            speech_features, prompt_embeds, input_ids, attention_mask
+        )
+
+        # --- 2. Generate Text using LLM ---
+        # Use the merged embeds/mask as the input to generate.
+        # Defaults mirror the generation parameters used in `decode`;
+        # user-provided llm_kwargs override them.
+        final_llm_kwargs = {
+            "bos_token_id": self.llm.config.bos_token_id,
+            "eos_token_id": self.llm.config.eos_token_id,
+            "pad_token_id": self.llm.config.pad_token_id,
+            "num_beams": 1,
+            "do_sample": True,
+            "top_p": 0.5,
+            "top_k": 20,
+            "repetition_penalty": 1.1,
+            "temperature": 0.7,
+            **(llm_kwargs or {}),  # User-provided kwargs override defaults
+        }
+
+        text_outputs = self.llm.generate(
+            inputs_embeds=merged_prompt_inputs_embeds,
+            attention_mask=merged_prompt_attention_mask,
+            max_new_tokens=max_text_new_tokens,
+            return_dict_in_generate=True,
+            **final_llm_kwargs,
+        )
+        # With return_dict_in_generate=True, generate returns an output object;
+        # the generated token IDs are in `.sequences` with shape [B, S_full].
+        generated_text_ids = text_outputs.sequences
+
+        # --- 3. Embed the *Full* Generated Text Sequence ---
+        # Re-embed the generated tokens with the LLM input embeddings instead of
+        # parsing the per-step hidden states returned by generate with output_hidden_states=True.
+        full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids)  # [B, S_full, D_llm]
+        full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long()  # [B, S_full]
+
+        # --- 4. Project Hidden States ---
+        projected_text_embeds = self.speech_token_projector(full_text_embeds)  # [B, S_full, D_codec]
+
+        # --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
+        self.codec_lm.to(device)
+        self.codec_lm_head.to(device)
+
+        # Initial input for the codec LM is the BOS token
+        current_speech_input_ids = torch.full(
+            (batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
+        )
+
+        past_key_values = None
+        generated_speech_tokens_list = [[] for _ in range(batch_size)]
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+
+        text_context_len = projected_text_embeds.shape[1]  # S_full
+
+        for t in range(max_speech_new_tokens):
+            # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
+            current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids)  # [B, 1, D_codec]
+
+            # Add the projected text embedding corresponding to the current timestep `t`
+            if t < text_context_len:
+                # Text context from the full generated text sequence
+                current_text_context_embed = projected_text_embeds[:, t : t + 1, :]  # [B, 1, D_codec]
+                inputs_embeds = current_speech_embeds + current_text_context_embed
+            else:
+                # No more text context to add
+                inputs_embeds = current_speech_embeds
+
+            # Ensure inputs_embeds has the correct dtype for the codec_lm
+            inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
+
+            # Forward pass through the codec LM for one step.
+            # output_hidden_states=True is needed because a CausalLM output has no
+            # `last_hidden_state`; the final layer's hidden state is hidden_states[-1].
+            codec_outputs = self.codec_lm(
+                inputs_embeds=inputs_embeds,  # Combined embedding for this step
+                past_key_values=past_key_values,
+                use_cache=True,
+                output_hidden_states=True,
+                return_dict=True,
+                # No attention mask needed here when using past_key_values and a single-token input
+            )
+
+            # Get logits for the token generated in this step from the last position
+            next_token_logits = self.codec_lm_head(codec_outputs.hidden_states[-1][:, -1:, :])
+
+            # --- Process Output & Update State ---
+            # Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
+            # TODO: Implement sampling/beam search for the codec LM if needed
+            next_token_ids = torch.argmax(next_token_logits, dim=-1)  # Greedy [B, 1]
+
+            # Mask out finished sequences
+            next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
+                self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
+
+            # Store generated tokens for unfinished sequences
+            for i in range(batch_size):
+                if unfinished_sequences[i]:
+                    token_id = next_token_ids[i].item()
+                    if token_id == self.codec_lm.config.eos_token_id:
+                        unfinished_sequences[i] = 0  # Mark as finished
+                    elif token_id != self.codec_lm.config.pad_token_id:
+                        generated_speech_tokens_list[i].append(token_id)
+
+            # Prepare for the next iteration
+            current_speech_input_ids = next_token_ids  # Newly generated token is the next input
+            past_key_values = codec_outputs.past_key_values  # Update KV cache
+
+            # Stop if all sequences are finished
+            if unfinished_sequences.max() == 0:
+                break
+
+        # --- 6. Return Results ---
+        return generated_text_ids, generated_speech_tokens_list
+

 def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     """Calculate accuracy.
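For reference, decode_with_speech_output is intended to be called the same way decode is called in decode.py above, returning the codec tokens alongside the text ids. A small usage sketch; the tensor names, shapes, and surrounding variables (model, fbank, tokenizer, device) are assumptions for illustration:

    import torch

    # fbank: (batch, num_frames, feat_dim) filterbank features
    # input_ids / attention_mask: tokenized text prompt from `preprocess`
    with torch.no_grad():
        text_ids, speech_tokens = model.decode_with_speech_output(
            fbank.to(device),
            input_ids.to(device, dtype=torch.long),
            attention_mask.to(device),
            max_text_new_tokens=1024,
            max_speech_new_tokens=1024,
        )

    hyps = tokenizer.batch_decode(text_ids, skip_special_tokens=True)
    # speech_tokens is a List[List[int]]; EOS/PAD ids are stripped inside the
    # generation loop, so each inner list holds only codec ids for one item.
    for hyp, tokens in zip(hyps, speech_tokens):
        print(hyp, len(tokens), "codec tokens")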
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index ef7e7a464..7b2e0e1f4 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -133,20 +133,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to use lora to fine-tune llm.",
     )

-    parser.add_argument(
-        "--unfreeze-llm",
-        type=str2bool,
-        default=False,
-        help="Whether to unfreeze llm during training.",
-    )
-
-    parser.add_argument(
-        "--unfreeze-speech-projector",
-        type=str2bool,
-        default=False,
-        help="Whether to unfreeze speech adaptor during training.",
-    )
-
     parser.add_argument(
         "--enable-speech-output",
         type=str2bool,
@@ -224,6 +210,19 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--unfreeze-llm",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze llm during training.",
+    )
+
+    parser.add_argument(
+        "--unfreeze-speech-projector",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze speech adaptor during training.",
+    )
     parser = deepspeed.add_config_arguments(parser)

     add_model_arguments(parser)
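The relocated flags only declare the options; how train.py consumes them is not shown in this diff. A rough sketch of the freeze/unfreeze policy such flags typically control, assuming the SPEECH_LLM attribute names used in model.py (llm, encoder_projector); the helper itself is hypothetical:

    def apply_freeze_policy(model, unfreeze_llm: bool, unfreeze_speech_projector: bool) -> None:
        """Toggle requires_grad on the text LLM and the speech projector."""
        for p in model.llm.parameters():
            p.requires_grad = unfreeze_llm
        for p in model.encoder_projector.parameters():
            p.requires_grad = unfreeze_speech_projector

    # e.g. apply_freeze_policy(model, args.unfreeze_llm, args.unfreeze_speech_projector)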