From 478d56efd8088c1c24cc82f275e2fbd18bf158ee Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 23 Apr 2025 07:33:27 +0000
Subject: [PATCH] fix bugs when padding right

---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh       |  52 ++---
 .../SPEECH2SPEECH/slam_omni/decode.py         |  58 ++---
 .../SPEECH2SPEECH/slam_omni/model.py          | 219 ++++++++++++------
 .../SPEECH2SPEECH/slam_omni/train.py          |   6 +-
 4 files changed, 187 insertions(+), 148 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index ef0e87465..7e145865e 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -35,8 +35,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi


-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 3: Combine features"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Combine features"
   manifest_dir=data/fbank
   if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
     pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
@@ -48,39 +48,27 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   fi
 fi

-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "stage 2: "
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Decoding (small speech2speech test)"
   python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
+    --max-duration 1 \
+    --exp-dir slam_omni/exp_speech2speech_test_flash_attn \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
+    --epoch 997 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method pure_text_sampling \
+    --method small_test_speech2speech \
+    --enable-speech-output True \
     --use-lora True # --on-the-fly-feats True
 fi

-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-  log "stage 2: "
-  python3 ./slam_omni/decode.py \
-    --max-duration 80 \
-    --exp-dir slam_omni/exp_speech2text \
-    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --epoch 999 --avg 1 \
-    --manifest-dir data/fbank \
-    --use-flash-attn True \
-    --method pure_text_sampling_original_0.5B \
-    --use-lora False # --on-the-fly-feats True
-fi
-
-ngpu=8
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "stage 3: "
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Training"
+  ngpu=8
 torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
   --max-duration 80 \
   --enable-musan False \
@@ -97,21 +85,23 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
 fi


-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "stage 4: "
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Training with speech output"
   ngpu=2
+  exp_dir=./slam_omni/exp_speech2speech_test_flash_attn
 torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
   --max-duration 40 \
   --enable-musan False \
-  --exp-dir ./slam_omni/exp_speech2text \
+  --exp-dir $exp_dir \
   --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
   --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
   --manifest-dir data/fbank \
   --deepspeed \
   --deepspeed_config ./slam_omni/ds_config_zero1.json \
-  --use-flash-attn False \
-  --use-lora True --unfreeze-llm False --enable-speech-output True
+  --use-flash-attn True \
+  --pretrained-model-path $exp_dir/epoch-1-checkpoint-35000.pt/pytorch_model.bin \
+  --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
   # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-  # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
+  # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
 fi
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
index 7de6f7b5d..2727b330b 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
@@ -62,9 +62,9 @@ from model import SPEECH_LLM, EncoderProjector
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from train import DEFAULT_SPEECH_TOKEN
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-
+from train import add_model_arguments
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -126,43 +126,6 @@ def average_checkpoints(
     return avg


-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--llm-path-or-name",
-        type=str,
-        default="",
-        help="Path or name of the large language model.",
-    )
-
-    parser.add_argument(
-        "--speech-encoder-path-or-name",
-        type=str,
-        default="whisper-large-v2",
-        help="Path or name of the speech encoder.",
-    )
-
-    parser.add_argument(
-        "--encoder-projector-ds-rate",
-        type=int,
-        default=8,
-        help="Downsample rate for the encoder projector.",
-    )
-
-    parser.add_argument(
-        "--use-flash-attn",
-        type=str2bool,
-        default=True,
-        help="Whether to use flash attention.",
-    )
-
-    parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="Whether to use lora fine-tuned llm checkpoint.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -351,16 +314,21 @@ def decode_one_batch(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        for cut_id, speech_output in zip(cut_ids, generated_speech_output):
-            # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-            #torchaudio.save(save_path, speech_output.cpu(), 16000)
-            print(f"speech_output: {speech_output}, cut_id: {cut_id}")
+        with open("test.txt", 'w') as f:
+            for cut_id in cut_ids:
+                # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
+                #torchaudio.save(save_path, speech_output.cpu(), 16000)
+                print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+                save_str = " ".join([str(i) for i in generated_speech_output])
+                f.write(f"{cut_id}|{save_str}\n")
+
     else:
         generated_ids = model.decode(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
         )
-        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
+        hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
+        print(f"hyps: {hyps}")
+        exit(0)
     return {"beam-search": hyps}

diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index 8ec707583..3f93db154 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from transformers.trainer_pt_utils import LabelSmoother
 from typing import List, Tuple  # Added for type hints
-
+from torchmetrics.classification import MulticlassAccuracy
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@@ -74,9 +74,21 @@ class SPEECH_LLM(nn.Module):
             self.codec_lm_head = nn.Linear(
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
             )
+            # cast the projector and the codec LM head to torch.float16
+            self.speech_token_projector = self.speech_token_projector.to(
+                dtype=torch.float16
+            )
+            self.codec_lm_head = self.codec_lm_head.to(dtype=torch.float16)
         self.loss_fct = torch.nn.CrossEntropyLoss()
         self.codec_lm_padding_side = codec_lm_padding_side
+        self.audio_accuracy_metric = MulticlassAccuracy(
+            self.codec_lm.vocab_size,
+            top_k=10,
+            average="micro",
+            multidim_average="global",
+            ignore_index=IGNORE_TOKEN_ID,
+        )

     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
     ):
@@ -332,7 +344,7 @@ class SPEECH_LLM(nn.Module):
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")

-        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)  # TODO: do we need to change bos tokens to pad token or mask token?
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

         # input_ids: seq_len T1, audio_codec seq_len T2
@@ -355,7 +367,12 @@ class SPEECH_LLM(nn.Module):
                 start_idx_content = T_audio - item_len # Start index of the content for item i
                 end_idx_target = start_idx_content + T_merged # End index of the target slice within the content
                 # Add the text_input_embeds to the calculated slice
-                audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
+                if end_idx_target > T_audio:
+                    # The merged text sequence is longer than the audio sequence; truncate it to fit
+                    cut_off_len = T_audio - start_idx_content
+                    audio_embeddings[i, start_idx_content:end_idx_target] = text_input_embeds[i, :cut_off_len]
+                else:
+                    audio_embeddings[i, start_idx_content:end_idx_target] += text_input_embeds[i]
         else:
             raise ValueError(f"Unsupported padding side: {self.codec_lm_padding_side}")

@@ -389,9 +406,12 @@ class SPEECH_LLM(nn.Module):
             audio_labels.detach(),
             ignore_label=IGNORE_TOKEN_ID,
         )
+        audio_topk_acc = self.audio_accuracy_metric(
+            audio_logits.detach(),
+            audio_labels.detach()).item()

-        return text_loss, acc, codec_loss, audio_acc
+        return text_loss, acc, codec_loss, audio_acc, audio_topk_acc

     def decode(
         self,
@@ -473,6 +493,7 @@ class SPEECH_LLM(nn.Module):
               - generated_speech_tokens: List of lists, where each inner list contains
                 the generated speech codec tokens for a batch item.
         """
+        assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation."
         if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head:
             raise ValueError("codec_lm and associated layers must be initialized to generate speech output.")

@@ -518,93 +539,88 @@ class SPEECH_LLM(nn.Module):
             attention_mask=merged_prompt_attention_mask,
             max_new_tokens=max_text_new_tokens,
             return_dict_in_generate=True,
+            output_hidden_states=True,
             **final_llm_kwargs
         )
-        for key in text_outputs:
-            print(key, text_outputs[key].shape)
-        # Assume text_outputs is the tensor of generated IDs [B, S_full]
-        generated_text_ids = text_outputs
-        exit(0)
-        # --- 3. Get LLM Hidden States for the *Full* Generated Text Sequence ---
-        # Run a separate forward pass to reliably get hidden states for the complete sequence.
-        # This is simpler than parsing the complex output of generate with output_hidden_states=True.
-        full_text_embeds = self.llm.get_input_embeddings()(generated_text_ids) # [B, S_full, D_llm]
-        full_text_attention_mask = (generated_text_ids != self.llm.config.pad_token_id).long() # [B, S_full]
+        generated_text_ids = text_outputs.sequences  # [B, S_full]
+        thinker_token_embeds = [
+            token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_hidden_states = [
+            token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states
+        ]
+        thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
+        thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0]

-        # --- 4. Project Hidden States ---
-        projected_text_embeds = self.speech_token_projector(full_text_embeds) # Shape [B, S_full, D_codec]
-
-        # --- 5. Generate Speech Tokens (Autoregressive Loop with Text Context) ---
-        self.codec_lm.to(device)
-        self.codec_lm_head.to(device)
-
-        # Initial input for the codec LM is the BOS token
-        current_speech_input_ids = torch.full(
-            (batch_size, 1), self.codec_lm.config.bos_token_id, dtype=torch.long, device=device
+        thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)  # [B, S_prompt, D_codec]
+        thinker_reply_part = self.speech_token_projector(thinker_reply_part)  # [B, S_reply, D_codec]
+
+        delay_step = 2
+        thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
+        talker_input_ids = torch.full(
+            (batch_size, thinker_prompt_part_seq_len + delay_step), self.codec_lm.config.bos_token_id, dtype=torch.long, device=self.llm.device
         )
+        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)  # [B, S_prompt + delay, D_codec]
+        thinker_input_embeds = torch.cat(
+            [
+                thinker_prompt_part,
+                thinker_reply_part[:, :delay_step, :],
+            ],
+            dim=1,
+        )
+        talker_inputs_embeds += thinker_input_embeds
+        thinker_reply_part = thinker_reply_part[:, delay_step:, :]  # [B, S_reply - delay, D_codec]
+
         past_key_values = None
-        generated_speech_tokens_list = [[] for _ in range(batch_size)]
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
-
-        text_context_len = projected_text_embeds.shape[1] # S_full
-
+        # generated_speech_tokens_list = [[] for _ in range(batch_size)]
+        # unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
+        generated_speech_tokens_list = []
+        next_token_ids = None
+        # text_context_len = projected_text_embeds.shape[1] # S_full
         for t in range(max_speech_new_tokens):
             # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
-            current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
-
-            # Add the projected text embedding corresponding to the current timestep `t`
-            if t < text_context_len:
-                # Text context from the full generated text sequence
-                current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
-                inputs_embeds = current_speech_embeds + current_text_context_embed
-            else:
-                # No more text context to add
-                inputs_embeds = current_speech_embeds
+            # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
+            if next_token_ids is not None:
+                talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids)  # [B, 1, D_codec]
+                if thinker_reply_part.shape[1] > 0:
+                    talker_inputs_embeds += thinker_reply_part[:, :1, :]
+                    thinker_reply_part = thinker_reply_part[:, 1:, :]  # Remove the first token for next step
+            # # Add the projected text embedding corresponding to the current timestep `t`
+            # if t < text_context_len:
+            #     # Text context from the full generated text sequence
+            #     current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
+            #     inputs_embeds = current_speech_embeds + current_text_context_embed
+            # else:
+            #     # No more text context to add
+            #     inputs_embeds = current_speech_embeds

-            # Ensure inputs_embeds has the correct dtype for the codec_lm
-            inputs_embeds = inputs_embeds.to(next(self.codec_lm.parameters()).dtype)
-
             # Forward pass through codec LM for one step
             # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
             codec_outputs = self.codec_lm(
-                inputs_embeds=inputs_embeds, # Combined embedding for this step
+                inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
                 past_key_values=past_key_values,
                 use_cache=True,
                 return_dict=True,
+                output_hidden_states=True,
                 # No attention mask needed here when using past_key_values and single token input
             )
-
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]  # [B, D_codec]
             # Get logits for the *last* token generated in this step
-            next_token_logits = self.codec_lm_head(codec_outputs.last_hidden_state[:, -1:, :]) # Use -1 index
-
-            # --- Process Output & Update State ---
-            # Greedy decoding (can be replaced with sampling based on codec_lm_kwargs)
-            # TODO: Implement sampling/beam search for codec LM if needed
-            next_token_ids = torch.argmax(next_token_logits, dim=-1) # Greedy [B, 1]
-
-            # Mask out finished sequences
-            next_token_ids = next_token_ids * unfinished_sequences.unsqueeze(-1) + \
-                self.codec_lm.config.pad_token_id * (1 - unfinished_sequences.unsqueeze(-1))
-
-            # Store generated tokens for unfinished sequences
-            for i in range(batch_size):
-                if unfinished_sequences[i]:
-                    token_id = next_token_ids[i].item()
-                    if token_id == self.codec_lm.config.eos_token_id:
-                        unfinished_sequences[i] = 0 # Mark as finished
-                    elif token_id != self.codec_lm.config.pad_token_id:
-                        generated_speech_tokens_list[i].append(token_id)
-
-            # Prepare for next iteration
-            current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
-            past_key_values = codec_outputs.past_key_values # Update KV cache
-
-            # Stop if all sequences are finished
-            if unfinished_sequences.max() == 0:
+            next_token_logits = self.codec_lm_head(last_token_hidden_state)  # [B, codec vocab size]
+            # suppress token ids in the range [4096, vocab_size - 3); only real codec tokens and the trailing special tokens remain
+            next_token_logits[:, 4096:-3] = -float("Inf")
+            next_token_ids = topk_sampling(
+                next_token_logits,
+            )
+            print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
+            if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                 break
-
+            # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
+            past_key_values = codec_outputs.past_key_values # Update KV cache
+            generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0])

         # --- 6. Return Results ---
         return generated_text_ids, generated_speech_tokens_list
@@ -626,3 +642,64 @@ def compute_accuracy(pad_outputs, pad_targets, ignore_label):
     )
     denominator = torch.sum(mask)
     return numerator.float() / denominator.float()
+
+
+def topk_sampling(
+    logits,
+    top_k=50,
+    top_p=0.95,
+    temperature=0.8,
+):
+    if temperature != 1.0:
+        logits = logits / temperature
+    # Top-p/top-k filtering
+    logits_filtered = top_k_top_p_filtering(
+        logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+    )
+    # Sample
+    probs = torch.nn.functional.softmax(logits_filtered, dim=-1)
+    tokens = torch.multinomial(probs, num_samples=1)
+
+    return tokens
+
+
+# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+    logits, top_k=20, top_p=0.5, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(
+            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+        )
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+    return logits
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index 7b2e0e1f4..c9ecf9400 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -417,7 +417,7 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     else:
-        text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
+        text_loss, acc, codec_loss, codec_acc, codec_topk_acc = model.forward_with_speech_output(
             fbank=feature,
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -442,6 +442,9 @@ def compute_loss(
         info["codec_acc"] = (
             codec_acc * info["frames"]
         )
+        info["codec_topk_acc"] = (
+            codec_topk_acc * info["frames"]
+        )
         info["codec_loss"] = codec_loss.detach().cpu().item()
         info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
@@ -743,6 +746,7 @@ def run(rank, world_size, args):
         #     torch_dtype=torch_dtype,
         # )
         codec_vocab_size = 8192
+        # TODO: modify the above vocab size or the suppress_tokens range used when decoding
         config = Qwen2Config(
             vocab_size=codec_vocab_size,
             hidden_size=1024,
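
For reference, below is a minimal, self-contained sketch (not part of the patch) of the token-suppression plus sampling step performed inside decode_with_speech_output() above. The vocabulary layout is an assumption taken from this patch: codec_vocab_size = 8192 as configured in train.py, with the suppression slice 4096:-3 leaving only the first 4096 codec ids and the last three (assumed bos/eos/pad) ids as candidates. Only plain top-k filtering is shown here; the patch additionally applies nucleus (top-p) filtering inside top_k_top_p_filtering().

import torch

torch.manual_seed(0)
codec_vocab_size = 8192                      # assumed, matches train.py in this patch
next_token_logits = torch.randn(1, codec_vocab_size)

# Mirror `next_token_logits[:, 4096:-3] = -float("Inf")`: ids 4096..8188 can never
# be sampled; ids 0..4095 (real codec tokens) and the last three special ids remain.
next_token_logits[:, 4096:-3] = float("-inf")

# Temperature scaling and a plain top-k cut, followed by multinomial sampling,
# using the same default values as topk_sampling() in model.py.
temperature, top_k = 0.8, 50
scaled = next_token_logits / temperature
kth_best = torch.topk(scaled, top_k).values[..., -1, None]
scaled = scaled.masked_fill(scaled < kth_best, float("-inf"))
probs = torch.softmax(scaled, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1)  # shape [1, 1], as in the loop
print(next_token_ids)

Replacing torch.multinomial with torch.argmax recovers the greedy decoding that the removed loop used.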