diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index 75bd9c576..ef0e87465 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -96,3 +96,22 @@ torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
     --use-lora True --unfreeze-llm True
 fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "stage 4: Train the speech-to-speech model (text and codec output)"
+  ngpu=2
+  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
+    --max-duration 40 \
+    --enable-musan False \
+    --exp-dir ./slam_omni/exp_speech2text \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+    --manifest-dir data/fbank \
+    --deepspeed \
+    --deepspeed_config ./slam_omni/ds_config_zero1.json \
+    --use-flash-attn False \
+    --use-lora True --unfreeze-llm False --enable-speech-output True
+    # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
+    # --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
+fi
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
index 55541f03e..f7e436806 100644
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py
@@ -58,11 +58,21 @@ class SPEECH_LLM(nn.Module):
         encoder: nn.Module,
         llm: nn.Module,
         encoder_projector: nn.Module,
+        codec_lm: nn.Module = None,
     ):
         super().__init__()
         self.encoder = encoder
         self.llm = llm
         self.encoder_projector = encoder_projector
+        self.codec_lm = codec_lm
+        if self.codec_lm:
+            self.speech_token_projector = nn.Linear(
+                self.llm.config.hidden_size, self.codec_lm.config.hidden_size
+            )
+            self.codec_lm_head = nn.Linear(
+                self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
+            )
+            self.loss_fct = torch.nn.CrossEntropyLoss()
 
     def _merge_input_ids_with_speech_features(
         self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None
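As a quick reference for the two new layers, a minimal sketch of the shape flow (sizes are illustrative: 896 is Qwen2.5-0.5B's hidden size, while 1024 and 8192 match the codec LM configured in train.py below):

import torch
import torch.nn as nn

llm_hidden, codec_hidden, codec_vocab = 896, 1024, 8192
speech_token_projector = nn.Linear(llm_hidden, codec_hidden)
codec_lm_head = nn.Linear(codec_hidden, codec_vocab)

text_hidden = torch.randn(2, 10, llm_hidden)       # (B, T_text, llm_hidden): LLM last hidden states (+ input embeddings)
projected = speech_token_projector(text_hidden)    # (B, T_text, codec_hidden), added onto the codec-LM input embeddings
codec_states = torch.randn(2, 25, codec_hidden)    # (B, T_codec, codec_hidden) from the codec LM
audio_logits = codec_lm_head(codec_states)         # (B, T_codec, codec_vocab)
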
@@ -225,8 +235,112 @@ class SPEECH_LLM(nn.Module):
                 labels.detach()[:, 1:],
                 ignore_label=IGNORE_TOKEN_ID,
             )
-        return model_outputs, acc
+        return model_outputs.loss, acc
 
+    def forward_with_speech_output(
+        self,
+        fbank: torch.Tensor = None,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        labels: torch.LongTensor = None,
+        speech_codec_ids: torch.LongTensor = None,
+    ):
+        encoder_outs = self.encoder(fbank)
+
+        speech_features = self.encoder_projector(encoder_outs)
+
+        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        (
+            inputs_embeds,
+            attention_mask,
+            labels,
+            _,
+        ) = self._merge_input_ids_with_speech_features(
+            speech_features, inputs_embeds, input_ids, attention_mask, labels
+        )
+
+        # find the first non-ignored (answer text) label position for each sample
+        text_label_start_index_list = []
+        for i in range(labels.shape[0]):
+            text_label_start_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0][0]
+            text_label_start_index_list.append(text_label_start_index)
+
+        model_outputs = self.llm(
+            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True
+        )
+        text_loss = model_outputs.loss
+
+        # prepare codec lm inputs
+        audio_codes_lens = torch.tensor(
+            [len(x) for x in speech_codec_ids], dtype=torch.int64, device=input_ids.device
+        )
+        # audio_codes_lens: (B,) lengths of the target codec sequences
+        max_len_speech_codec = max(audio_codes_lens)
+        delay_step = 2
+        audio_codes = torch.full(
+            (inputs_embeds.shape[0], max_len_speech_codec + inputs_embeds.shape[1] + 1),
+            self.codec_lm.config.pad_token_id,
+            dtype=torch.int64,
+            device=input_ids.device
+        )
+        audio_labels = audio_codes.clone()
+
+        for i, speech_codec in enumerate(speech_codec_ids):
+            text_label_start_index = text_label_start_index_list[i]
+            speech_codec = torch.tensor(
+                speech_codec, dtype=torch.int64, device=input_ids.device
+            )
+            # codec tokens start delay_step steps after the first answer-text position
+            audio_codes[i, :text_label_start_index + delay_step + 1] = self.codec_lm.config.bos_token_id  # BOS doubles as the mask token
+            audio_codes[i, text_label_start_index + delay_step + 1 : text_label_start_index + delay_step + 1 + len(speech_codec)] = speech_codec
+            audio_labels[i, text_label_start_index + delay_step : text_label_start_index + delay_step + len(speech_codec)] = speech_codec
+            audio_labels[i, text_label_start_index + delay_step + len(speech_codec)] = self.codec_lm.config.eos_token_id
+
+        audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
+        audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)
+
+        # text length T1 <= codec length T2: add the projected text states onto the first T1 codec embeddings
+        text_last_hidden_outputs = model_outputs.hidden_states[-1]
+        text_input_embeds = inputs_embeds + text_last_hidden_outputs
+        text_input_embeds = self.speech_token_projector(text_input_embeds)
+
+        audio_embeddings[:, : text_input_embeds.shape[1]] += text_input_embeds
+
+        speech_outputs = self.codec_lm(
+            attention_mask=audio_attention_mask,
+            inputs_embeds=audio_embeddings,
+            return_dict=True,
+            output_hidden_states=True,
+        )
+        last_hidden_state = speech_outputs.hidden_states[-1].clone()
+
+        audio_logits = self.codec_lm_head(last_hidden_state)  # (B, T, vocab_size)
+        audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size)
+        audio_labels = audio_labels.contiguous().view(-1)
+        audio_labels = audio_labels.masked_fill(
+            audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID
+        )
+        codec_loss = self.loss_fct(audio_logits, audio_labels)
+        audio_preds = torch.argmax(audio_logits, -1)
+
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc = compute_accuracy(
+                preds.detach()[:, :-1],
+                labels.detach()[:, 1:],
+                ignore_label=IGNORE_TOKEN_ID,
+            )
+            audio_acc = compute_accuracy(
+                audio_preds.detach(),
+                audio_labels.detach(),
+                ignore_label=IGNORE_TOKEN_ID,
+            )
+
+        return text_loss, acc, codec_loss, audio_acc
+
     def decode(
         self,
         fbank: torch.Tensor = None,
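The indexing in forward_with_speech_output is easiest to see on a toy example. A minimal sketch of the delay pattern, with made-up sizes; the special ids follow train.py below (pad = vocab - 1, eos = vocab - 2, bos = vocab - 3), and positions whose label equals pad are mapped to IGNORE_TOKEN_ID, so only the codec span plus the EOS contribute to codec_loss:

import torch

pad, eos, bos = 8191, 8190, 8189      # codec_vocab_size - 1 / - 2 / - 3, as set in train.py
t, delay_step = 3, 2                  # t: first non-ignored text-label position for this sample
codec = torch.tensor([11, 12, 13])    # toy codec token ids
T = 12                                # stands in for max_len_speech_codec + text_len + 1

codes = torch.full((T,), pad)
labels = torch.full((T,), pad)
codes[: t + delay_step + 1] = bos                                     # text prefix is masked with BOS
codes[t + delay_step + 1 : t + delay_step + 1 + len(codec)] = codec   # codec inputs start one step later
labels[t + delay_step : t + delay_step + len(codec)] = codec          # labels lead the inputs by one step
labels[t + delay_step + len(codec)] = eos                             # predict EOS after the last codec token

print(codes.tolist())   # [8189]*6 + [11, 12, 13] + [8191]*3
print(labels.tolist())  # [8191]*5 + [11, 12, 13, 8190] + [8191]*3
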
diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
index f0df303e4..9823492bf 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py
@@ -70,7 +70,12 @@ from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+)
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
 from icefall import diagnostics
@@ -135,6 +140,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to unfreeze llm during training.",
     )
 
+    parser.add_argument(
+        "--unfreeze-speech-projector",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze the speech projector (adapter) during training.",
+    )
+
+    parser.add_argument(
+        "--enable-speech-output",
+        type=str2bool,
+        default=False,
+        help="Whether to enable speech codec output.",
+    )
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -307,7 +325,7 @@ def compute_loss(
     )
     # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
     # remove too long text
-    texts = [ text for text in texts if len(text) < 1024 ]
+    # texts = [ text for text in texts if len(text) < 1024 ]
     if len(texts) != len(messages):
         logging.warning(
             f"Remove too long text, {messages} "
@@ -392,13 +410,22 @@
     input_ids = input_ids.type(torch.LongTensor)
 
     with torch.set_grad_enabled(is_training):
-        model_outputs, acc = model(
-            fbank=feature,
-            input_ids=input_ids.to(device),
-            attention_mask=attention_mask.to(device),
-            labels=target_ids.to(device),
-        )
-        loss = model_outputs.loss
+        if not params.enable_speech_output:
+            loss, acc = model(
+                fbank=feature,
+                input_ids=input_ids.to(device),
+                attention_mask=attention_mask.to(device),
+                labels=target_ids.to(device),
+            )
+        else:
+            text_loss, acc, codec_loss, codec_acc = model.forward_with_speech_output(
+                fbank=feature,
+                input_ids=input_ids.to(device),
+                attention_mask=attention_mask.to(device),
+                labels=target_ids.to(device),
+                speech_codec_ids=answer_cosyvoice_speech_token,
+            )
+            loss = text_loss + codec_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -412,7 +439,12 @@
     info["acc"] = (
         acc * info["frames"]
     )  # WAR: to avoid normalization by the number of frames
-
+    if params.enable_speech_output:
+        info["codec_acc"] = (
+            codec_acc * info["frames"]
+        )
+        info["codec_loss"] = codec_loss.detach().cpu().item()
+        info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info
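The acc * info["frames"] product (and now the codec_acc counterpart) follows the "WAR" comment: MetricsTracker sums each field over batches and normalizes by the accumulated frame count when it reports values, so storing the product makes the logged number a frame-weighted average accuracy rather than accuracy divided by frames. A toy illustration with made-up numbers:

batches = [(0.50, 100), (0.80, 300)]                    # (accuracy, frames) per batch, invented values
tot_acc = sum(acc * frames for acc, frames in batches)  # what MetricsTracker accumulates
tot_frames = sum(frames for _, frames in batches)
print(tot_acc / tot_frames)                             # 0.725: frame-weighted mean accuracy
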
@@ -429,7 +461,7 @@
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+        with torch.amp.autocast('cuda', enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
                 tokenizer=tokenizer,
@@ -544,7 +576,7 @@
                 f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
             )
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.amp.autocast('cuda', enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
                     tokenizer=tokenizer,
@@ -629,6 +661,7 @@ def run(rank, world_size, args):
     speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
         # torch_dtype=torch.bfloat16 FIX ME
@@ -672,6 +705,16 @@
     special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
     tokenizer.add_special_tokens(special_tokens_dict)
+    # original_tokenizer_vocab_size = len(tokenizer)
+    # cosyvoice2_token_size = 6561
+    # new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
+    #     "<|SPEECH_GENERATION_START|>"
+    # ]
+    # num_added_tokens = tokenizer.add_tokens(new_tokens)
+    # model.resize_token_embeddings(len(tokenizer))
+    # model.vocab_size = len(tokenizer)
+
     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
     )
@@ -680,11 +723,66 @@
     encoder_projector = EncoderProjector(
         speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
     )
+    if not params.unfreeze_speech_projector:
+        for name, param in encoder_projector.named_parameters():
+            param.requires_grad = False
+        encoder_projector.eval()
+
+    if params.enable_speech_output:
+        if params.use_flash_attn:
+            attn_implementation = "flash_attention_2"
+        else:
+            attn_implementation = "eager"
+        torch_dtype = torch.float16
+        # codec_lm = AutoModelForCausalLM.from_pretrained(
+        #     params.llm_path_or_name,
+        #     attn_implementation=attn_implementation,
+        #     torch_dtype=torch_dtype,
+        # )
+        codec_vocab_size = 8192
+        config = Qwen2Config(
+            vocab_size=codec_vocab_size,
+            hidden_size=1024,
+            num_hidden_layers=12,
+            num_attention_heads=16,
+            num_key_value_heads=16,
+            intermediate_size=2048,
+            max_position_embeddings=4096,
+        )
+        codec_lm = Qwen2ForCausalLM(config=config)
+        # cosyvoice2_token_size = 6561
+        codec_lm.resize_token_embeddings(codec_vocab_size)
+        codec_lm.vocab_size = codec_vocab_size
+        codec_lm.config.pad_token_id = codec_vocab_size - 1
+        codec_lm.config.eos_token_id = codec_vocab_size - 2
+        codec_lm.config.bos_token_id = codec_vocab_size - 3
+        if params.use_lora:
+            lora_config = LoraConfig(
+                r=64,
+                lora_alpha=16,
+                target_modules=[
+                    "q_proj",
+                    "k_proj",
+                    "v_proj",
+                    "o_proj",
+                    "up_proj",
+                    "gate_proj",
+                    "down_proj",
+                ],
+                lora_dropout=0.05,
+                task_type="CAUSAL_LM",
+            )
+            codec_lm = get_peft_model(codec_lm, lora_config)
+            codec_lm.print_trainable_parameters()
+    else:
+        codec_lm = None
 
     model = SPEECH_LLM(
         speech_encoder,
         llm,
         encoder_projector,
+        codec_lm,
     )
 
     if params.pretrained_model_path:
@@ -728,6 +826,13 @@
             #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             # )
             return False
+        # each cut stores its target codec ids in cut.custom["answer_cosyvoice_speech_token"]
+        codec_len = len(c.custom["answer_cosyvoice_speech_token"])
+        if codec_len > 2048:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, length: {codec_len}"
+            )
+            return False
         return True
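A small sanity sketch of the codec vocabulary layout assumed above: the 6561 comes from the commented-out cosyvoice2_token_size, and the codec LM reserves the top three ids of its 8192-entry vocabulary for pad/eos/bos, so real codec ids from answer_cosyvoice_speech_token never collide with the specials:

codec_vocab_size = 8192
pad_token_id = codec_vocab_size - 1   # 8191
eos_token_id = codec_vocab_size - 2   # 8190
bos_token_id = codec_vocab_size - 3   # 8189, also used to mask the text prefix in model.py
cosyvoice2_token_size = 6561          # CosyVoice2 speech-token codebook size (from the commented-out block)
assert cosyvoice2_token_size <= bos_token_id, "codec ids must stay below the reserved special ids"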