diff --git a/egs/speech_llm/SPEECH2SPEECH/README.md b/egs/speech_llm/SPEECH2SPEECH/README.md
new file mode 100644
index 000000000..e4738eeef
--- /dev/null
+++ b/egs/speech_llm/SPEECH2SPEECH/README.md
@@ -0,0 +1,89 @@
+
+# Introduction
+
+This recipe includes scripts for training speech2speech models.
+
+# SPEECH2SPEECH
+
+The following table lists the folders for different tasks.
+
+| Recipe | Speech Input | Speech Output | Comment |
+|--------------|--------------|---------------|--------|
+| Qwen-Omni-like | Continuous embeddings | CosyVoice1 50 Hz single-codebook tokens | Text-driven; a Thinker LLM generates text tokens and a small Talker LLM generates speech tokens |
+
+### [Qwen-Omni-like Speech2Speech Recipe](./qwen_omni)
+
+A [Qwen2.5-Omni](https://github.com/QwenLM/Qwen2.5-Omni)-style model trained on the [worstchan/Belle_1.4M-SLAM-Omni](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni) dataset.
+
+![Model architecture](./assets/framework.jpg)
+
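+Before training, extract speech features and manifests with `local/compute_whisper_fbank.py`. The command below mirrors the usage example in that script's docstring and prepares the UltraChat SLAM-Omni data; the Belle_1.4M-SLAM-Omni data can be processed the same way by changing `--huggingface-dataset-path-or-name` and `--prefix` (verify that the `--audio-key`/`--text-key` names match the dataset you use):
+
+```bash
+python3 local/compute_whisper_fbank.py \
+  --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
+  --out-dir data/fbank \
+  --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \
+  --audio-key question_audio --text-key answer \
+  --prefix ultrachat
+```
+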
+
+Command for training:
+```bash
+pip install -r whisper_llm_zh/requirements.txt
+
+pip install huggingface_hub['cli']
+mkdir -p models/whisper models/qwen
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+
+# First, we only train the projector and freeze other modules.
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+  --max-duration 200 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+  --manifest-dir data/fbank \
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn True \
+  --use-lora False --unfreeze-llm False
+
+# Then we jointly train the projector and LLM LoRA modules.
+torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
+  --max-duration 200 \
+  --exp-dir ./whisper_llm_zh/exp_test \
+  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+  --manifest-dir data/fbank \
+  --deepspeed \
+  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --use-flash-attn True \
+  --use-lora True --unfreeze-llm True \
+  --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
+```
+
+Command for decoding:
+```bash
+mkdir -p models/whisper models/qwen models/checkpoint
+huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
+
+# For aishell fine-tuned whisper model
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+# For multi-hans fine-tuned whisper model
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+
+mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
+ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
+
+python3 ./whisper_llm_zh/decode.py \
+  --max-duration 80 \
+  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
+  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
+  --llm-path-or-name models/qwen \
+  --epoch 999 --avg 1 \
+  --manifest-dir data/fbank \
+  --use-flash-attn True \
+  --use-lora True --dataset aishell
+```
diff --git a/egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg b/egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg
new file mode 100644
index 000000000..d708bb256
Binary files /dev/null and b/egs/speech_llm/SPEECH2SPEECH/assets/framework.jpg differ
diff --git a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
index b01a35c7d..4bc5e5a82 100755
--- a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
+++ b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py
@@ -17,6 +17,16 @@
 # WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Usage: + python3 local/compute_whisper_fbank.py \ + --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ + --out-dir data/fbank \ + --huggingface-dataset-path-or-name worstchan/UltraChat-300K-SLAM-Omni \ + --audio-key question_audio --text-key answer \ + --prefix ultrachat +""" + import argparse import logging @@ -126,7 +136,7 @@ def compute_fbank(args): num_digits = 5 for i in range(num_shards): shard = dataset.shard(num_shards, i) - shard = shard.take(10) # for testing + # shard = shard.take(10) # for testing logging.info( f"Loading dataset shard {i} from {args.huggingface_dataset_path_or_name}" ) @@ -159,8 +169,6 @@ def compute_fbank(args): logging.info(f"Saving to {cuts_path}") # see https://github.com/lhotse-speech/lhotse/issues/1125 cut_set.drop_recordings().to_file(cuts_path) - if i > 1: - break def main(): diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py similarity index 100% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/data_module.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py similarity index 100% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json similarity index 100% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/ds_config_zero1.json rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/ds_config_zero1.json diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/label_smoothing.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py similarity index 100% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/label_smoothing.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/label_smoothing.py diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py similarity index 77% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py index 0cc93c237..97870337d 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/model.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py @@ -1,11 +1,14 @@ +from typing import List, Tuple # Added for type hints + import torch from torch import nn -from transformers.trainer_pt_utils import LabelSmoother -from typing import List, Tuple # Added for type hints from torchmetrics.classification import MulticlassAccuracy +from transformers.trainer_pt_utils import LabelSmoother + IGNORE_TOKEN_ID = LabelSmoother.ignore_index import logging + class EncoderProjector(nn.Module): """ The encoder projector module. It is used to project the encoder outputs to the same dimension as the language model. 
@@ -69,7 +72,8 @@ class SPEECH_LLM(nn.Module): self.codec_lm = codec_lm if self.codec_lm: self.speech_token_projector = nn.Linear( - self.llm.config.hidden_size + self.llm.config.hidden_size, self.codec_lm.config.hidden_size + self.llm.config.hidden_size + self.llm.config.hidden_size, + self.codec_lm.config.hidden_size, ) self.codec_lm_head = nn.Linear( self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size @@ -89,6 +93,7 @@ class SPEECH_LLM(nn.Module): multidim_average="global", ignore_index=IGNORE_TOKEN_ID, ) + def _merge_input_ids_with_speech_features( self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None ): @@ -274,68 +279,115 @@ class SPEECH_LLM(nn.Module): ) = self._merge_input_ids_with_speech_features( speech_features, inputs_embeds, input_ids, attention_mask, labels ) - input_seq_len = attention_mask.sum(dim=1) # shape, B - text_label_start_index_list, text_input_start_index_list, input_question_len_list = [], [], [] + input_seq_len = attention_mask.sum(dim=1) # shape, B + ( + text_label_start_index_list, + text_input_start_index_list, + input_question_len_list, + ) = ([], [], []) for i in range(labels.shape[0]): input_embeds_valid_index = torch.where(attention_mask[i] != 0)[0] input_embeds_start_index = input_embeds_valid_index[0] text_labels_valid_index = torch.where(labels[i] != IGNORE_TOKEN_ID)[0] text_labels_start_index = text_labels_valid_index[0] - assert input_seq_len[i] == input_embeds_valid_index[-1] - input_embeds_start_index + 1, f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}" - assert input_embeds_valid_index[-1] == text_labels_valid_index[-1], f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}" + assert ( + input_seq_len[i] + == input_embeds_valid_index[-1] - input_embeds_start_index + 1 + ), f"input_seq_len: {input_seq_len[i]}, input_embeds_valid_index: {input_embeds_valid_index}, input_embeds_start_index: {input_embeds_start_index}" + assert ( + input_embeds_valid_index[-1] == text_labels_valid_index[-1] + ), f"input_embeds_valid_index: {input_embeds_valid_index}, text_labels_valid_index: {text_labels_valid_index}" input_question_len = text_labels_start_index - input_embeds_start_index - assert input_question_len + text_labels_valid_index[-1] - text_labels_start_index + 1 == input_seq_len[i] + assert ( + input_question_len + + text_labels_valid_index[-1] + - text_labels_start_index + + 1 + == input_seq_len[i] + ) text_label_start_index_list.append(text_labels_start_index) text_input_start_index_list.append(input_embeds_start_index) input_question_len_list.append(input_question_len) model_outputs = self.llm( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, output_hidden_states=True + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + output_hidden_states=True, ) text_loss = model_outputs.loss delay_step = 1 # prepare codec lm inputs - audio_codes_lens = [len(x) + input_question_len_list[i] + delay_step + 1 for i, x in enumerate(speech_codec_ids)] + audio_codes_lens = [ + len(x) + input_question_len_list[i] + delay_step + 1 + for i, x in enumerate(speech_codec_ids) + ] max_len_speech_codec = max(audio_codes_lens) if self.codec_lm_padding_side == "right": audio_codes = [ - [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + 
[self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] + * (input_question_len_list[i] + delay_step) + + [self.codec_lm.config.bos_token_id] + + x + + [self.codec_lm.config.pad_token_id] + * (max_len_speech_codec - audio_codes_lens[i]) for i, x in enumerate(speech_codec_ids) ] audio_labels = [ - [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] + * (input_question_len_list[i] + delay_step) + + x + + [self.codec_lm.config.eos_token_id] + + [self.codec_lm.config.pad_token_id] + * (max_len_speech_codec - audio_codes_lens[i]) for i, x in enumerate(speech_codec_ids) ] elif self.codec_lm_padding_side == "left": audio_codes = [ - [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.mask_token_id] * (input_question_len_list[i] + delay_step) + [self.codec_lm.config.bos_token_id] + x + [self.codec_lm.config.pad_token_id] + * (max_len_speech_codec - audio_codes_lens[i]) + + [self.codec_lm.config.mask_token_id] + * (input_question_len_list[i] + delay_step) + + [self.codec_lm.config.bos_token_id] + + x for i, x in enumerate(speech_codec_ids) ] audio_labels = [ - [self.codec_lm.config.pad_token_id] * (max_len_speech_codec - audio_codes_lens[i]) + [self.codec_lm.config.pad_token_id] * (input_question_len_list[i] + delay_step) + x + [self.codec_lm.config.eos_token_id] + [self.codec_lm.config.pad_token_id] + * (max_len_speech_codec - audio_codes_lens[i]) + + [self.codec_lm.config.pad_token_id] + * (input_question_len_list[i] + delay_step) + + x + + [self.codec_lm.config.eos_token_id] for i, x in enumerate(speech_codec_ids) ] audio_codes = torch.tensor( - audio_codes, - dtype=torch.int64, - device=input_ids.device + audio_codes, dtype=torch.int64, device=input_ids.device ) audio_labels = torch.tensor( - audio_labels, - dtype=torch.int64, - device=input_ids.device + audio_labels, dtype=torch.int64, device=input_ids.device ) audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id) audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes) - + text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], [] for i in range(len(text_label_start_index_list)): - text_last_hidden = model_outputs.hidden_states[-1][i, text_input_start_index_list[i]:text_input_start_index_list[i] + input_seq_len[i] - 1] + text_last_hidden = model_outputs.hidden_states[-1][ + i, + text_input_start_index_list[i] : text_input_start_index_list[i] + + input_seq_len[i] + - 1, + ] text_last_hidden_lists.append(text_last_hidden) - text_embed = inputs_embeds[i, text_input_start_index_list[i] + 1:text_input_start_index_list[i] + input_seq_len[i]] # exclude bos + text_embed = inputs_embeds[ + i, + text_input_start_index_list[i] + + 1 : text_input_start_index_list[i] + + input_seq_len[i], + ] # exclude bos text_embeds_list.append(text_embed) text_input_embeds = torch.cat( @@ -344,22 +396,34 @@ class SPEECH_LLM(nn.Module): text_embed, ], dim=-1, - )# shape, T, D1 + D2 - text_input_embeds = self.speech_token_projector(text_input_embeds) # shape, T, D_codec + ) # shape, T, D1 + D2 + text_input_embeds = self.speech_token_projector( + text_input_embeds + ) # shape, T, D_codec text_input_embeds_list.append(text_input_embeds) - + for i in range(audio_embeddings.shape[0]): text_input_embeds = 
text_input_embeds_list[i] if self.codec_lm_padding_side == "right": - audio_embeddings[i, :text_input_embeds.shape[0]] += text_input_embeds + audio_embeddings[i, : text_input_embeds.shape[0]] += text_input_embeds elif self.codec_lm_padding_side == "left": - start_idx = torch.where(audio_codes[i] == self.codec_lm.config.mask_token_id)[0][0] + start_idx = torch.where( + audio_codes[i] == self.codec_lm.config.mask_token_id + )[0][0] start_idx_re_compute = torch.where(audio_attention_mask[i] != 0)[0][0] - assert start_idx == start_idx_re_compute, f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}" + assert ( + start_idx == start_idx_re_compute + ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}" if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx: - text_input_embeds = text_input_embeds[:audio_embeddings.shape[1] - start_idx] - logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}") - audio_embeddings[i, start_idx:start_idx + text_input_embeds.shape[0]] += text_input_embeds + text_input_embeds = text_input_embeds[ + : audio_embeddings.shape[1] - start_idx + ] + logging.warning( + f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}" + ) + audio_embeddings[ + i, start_idx : start_idx + text_input_embeds.shape[0] + ] += text_input_embeds speech_outputs = self.codec_lm( attention_mask=audio_attention_mask, @@ -369,8 +433,10 @@ class SPEECH_LLM(nn.Module): ) last_hidden_state = speech_outputs.hidden_states[-1].clone() - audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size - audio_logits = audio_logits.contiguous().view(-1, self.codec_lm.config.vocab_size) + audio_logits = self.codec_lm_head(last_hidden_state) # shape, B, T, vocab_size + audio_logits = audio_logits.contiguous().view( + -1, self.codec_lm.config.vocab_size + ) audio_labels = audio_labels.contiguous().view(-1) audio_labels = audio_labels.masked_fill( audio_labels == self.codec_lm.config.pad_token_id, IGNORE_TOKEN_ID @@ -378,7 +444,6 @@ class SPEECH_LLM(nn.Module): codec_loss = self.loss_fct(audio_logits, audio_labels) audio_preds = torch.argmax(audio_logits, -1) - with torch.no_grad(): preds = torch.argmax(model_outputs.logits, -1) acc = compute_accuracy( @@ -392,12 +457,11 @@ class SPEECH_LLM(nn.Module): ignore_label=IGNORE_TOKEN_ID, ) audio_topk_acc = self.audio_accuracy_metric( - audio_logits.detach(), - audio_labels.detach()).item() - + audio_logits.detach(), audio_labels.detach() + ).item() return text_loss, acc, codec_loss, audio_acc, audio_topk_acc - + def decode( self, fbank: torch.Tensor = None, @@ -453,12 +517,12 @@ class SPEECH_LLM(nn.Module): def decode_with_speech_output( self, fbank: torch.Tensor = None, - input_ids: torch.LongTensor = None, # Prompt input_ids - attention_mask: torch.Tensor = None, # Prompt attention_mask + input_ids: torch.LongTensor = None, # Prompt input_ids + attention_mask: torch.Tensor = None, # Prompt attention_mask max_text_new_tokens: int = 1024, - max_speech_new_tokens: int = 1024, # Max length for speech tokens - llm_kwargs: dict = None, # Kwargs for text LLM generate - codec_lm_kwargs: dict = None # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET + max_speech_new_tokens: int = 1024, # Max length for speech tokens + llm_kwargs: dict = None, # Kwargs for text LLM generate + codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED 
YET ) -> Tuple[torch.LongTensor, List[List[int]]]: """ Generates text and corresponding speech tokens using the revised logic. @@ -479,16 +543,22 @@ class SPEECH_LLM(nn.Module): the generated speech codec tokens for a batch item. """ assert fbank.shape[0] == 1, "Batch size must be 1 for speech generation." - if not self.codec_lm or not self.speech_token_projector or not self.codec_lm_head: - raise ValueError("codec_lm and associated layers must be initialized to generate speech output.") + if ( + not self.codec_lm + or not self.speech_token_projector + or not self.codec_lm_head + ): + raise ValueError( + "codec_lm and associated layers must be initialized to generate speech output." + ) - device = next(self.parameters()).device # Use model's device + device = next(self.parameters()).device # Use model's device batch_size = fbank.shape[0] # --- 1. Prepare Prompt Embeddings --- encoder_outs = self.encoder(fbank) speech_features = self.encoder_projector(encoder_outs) - speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype + speech_features = speech_features.to(self.llm.dtype) # Ensure matching dtype prompt_embeds = self.llm.get_input_embeddings()(input_ids) @@ -511,12 +581,12 @@ class SPEECH_LLM(nn.Module): "eos_token_id": self.llm.config.eos_token_id, "pad_token_id": self.llm.config.pad_token_id, "num_beams": 1, - "do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed + "do_sample": True, # Typically false for S2ST/S2TT tasks unless exploration needed "top_p": 0.5, "top_k": 20, "repetition_penalty": 1.1, "temperature": 0.7, - **(llm_kwargs or {}) # User-provided kwargs override defaults + **(llm_kwargs or {}), # User-provided kwargs override defaults } text_outputs = self.llm.generate( @@ -525,17 +595,22 @@ class SPEECH_LLM(nn.Module): max_new_tokens=max_text_new_tokens, return_dict_in_generate=True, output_hidden_states=True, - **final_llm_kwargs + **final_llm_kwargs, ) delay_step = 1 - generated_text_ids = text_outputs.sequences # [B, S_full] + generated_text_ids = text_outputs.sequences # [B, S_full] eos_token_id = self.llm.config.eos_token_id - eos_token_embedding = self.llm.get_input_embeddings()(torch.tensor([[eos_token_id]], device=device)) # 1,D - assert generated_text_ids[0, -1] == eos_token_id, f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}" + eos_token_embedding = self.llm.get_input_embeddings()( + torch.tensor([[eos_token_id]], device=device) + ) # 1,D + assert ( + generated_text_ids[0, -1] == eos_token_id + ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}" thinker_token_embeds_org = [ - token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states + token_hidden_states[0].to(self.llm.device) + for token_hidden_states in text_outputs.hidden_states ] - # shift one for thinker token_embeds, drop the first embeds, and add the eos token + # shift one for thinker token_embeds, drop the first embeds, and add the eos token first_thinker_token_embed = torch.cat( [ thinker_token_embeds_org[0][:, 1:], @@ -544,19 +619,27 @@ class SPEECH_LLM(nn.Module): dim=1, ) - thinker_token_embeds = [first_thinker_token_embed] + thinker_token_embeds_org[2:] + [eos_token_embedding] + thinker_token_embeds = ( + [first_thinker_token_embed] + + thinker_token_embeds_org[2:] + + [eos_token_embedding] + ) thinker_hidden_states = [ - token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states + token_hidden_states[-1].to(self.llm.device) + 
for token_hidden_states in text_outputs.hidden_states ] # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) - thinker_reply_part = [torch.cat( - [ - thinker_hidden_state, - thinker_token_embed, - ], - dim=-1, - ) - for thinker_hidden_state, thinker_token_embed in zip(thinker_hidden_states[1:], thinker_token_embeds[1:]) + thinker_reply_part = [ + torch.cat( + [ + thinker_hidden_state, + thinker_token_embed, + ], + dim=-1, + ) + for thinker_hidden_state, thinker_token_embed in zip( + thinker_hidden_states[1:], thinker_token_embeds[1:] + ) ] thinker_reply_part = torch.cat(thinker_reply_part, dim=1) # thinker_prompt_part = thinker_hidden_states[0] + thinker_token_embeds[0] @@ -568,26 +651,35 @@ class SPEECH_LLM(nn.Module): dim=-1, ) - thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) # [B, S_full, D_codec] - thinker_reply_part = self.speech_token_projector(thinker_reply_part) # [B, S_full, D_codec] - + thinker_prompt_part = self.speech_token_projector( + thinker_prompt_part + ) # [B, S_full, D_codec] + thinker_reply_part = self.speech_token_projector( + thinker_reply_part + ) # [B, S_full, D_codec] thinker_prompt_part_seq_len = thinker_prompt_part.shape[1] talker_input_ids = torch.full( - (batch_size, thinker_prompt_part_seq_len + delay_step + 1), self.codec_lm.config.mask_token_id, dtype=torch.long, device=self.llm.device + (batch_size, thinker_prompt_part_seq_len + delay_step + 1), + self.codec_lm.config.mask_token_id, + dtype=torch.long, + device=self.llm.device, ) - talker_input_ids[:,-1] = self.codec_lm.config.bos_token_id - talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) # [B, S_full, D_codec] + talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id + talker_inputs_embeds = self.codec_lm.get_input_embeddings()( + talker_input_ids + ) # [B, S_full, D_codec] thinker_input_embeds = torch.cat( [ thinker_prompt_part, - thinker_reply_part[:, :delay_step + 1, :], + thinker_reply_part[:, : delay_step + 1, :], ], dim=1, ) talker_inputs_embeds += thinker_input_embeds - thinker_reply_part = thinker_reply_part[:, delay_step + 1:, :] # [B, S_full, D_codec] - + thinker_reply_part = thinker_reply_part[ + :, delay_step + 1 :, : + ] # [B, S_full, D_codec] past_key_values = None # generated_speech_tokens_list = [[] for _ in range(batch_size)] @@ -599,10 +691,14 @@ class SPEECH_LLM(nn.Module): # Get embedding for the *current* input token ID (initially BOS, then generated tokens) # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec] if t > 0: - talker_inputs_embeds = self.codec_lm.get_input_embeddings()(next_token_ids) # [B, 1, D_codec] + talker_inputs_embeds = self.codec_lm.get_input_embeddings()( + next_token_ids + ) # [B, 1, D_codec] if thinker_reply_part.shape[1] > 0: talker_inputs_embeds += thinker_reply_part[:, :1, :] - thinker_reply_part = thinker_reply_part[:, 1:, :] # Remove the first token for next step + thinker_reply_part = thinker_reply_part[ + :, 1:, : + ] # Remove the first token for next step # # Add the projected text embedding corresponding to the current timestep `t` # if t < text_context_len: # # Text context from the full generated text sequence @@ -611,20 +707,24 @@ class SPEECH_LLM(nn.Module): # else: # # No more text context to add # inputs_embeds = current_speech_embeds - + # Forward pass through codec LM for one step # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation codec_outputs = 
self.codec_lm( - inputs_embeds=talker_inputs_embeds, # Combined embedding for this step + inputs_embeds=talker_inputs_embeds, # Combined embedding for this step past_key_values=past_key_values, use_cache=True, return_dict=True, output_hidden_states=True, # No attention mask needed here when using past_key_values and single token input ) - last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] # [B, D_codec] #TODO: check shape here + last_token_hidden_state = codec_outputs.hidden_states[-1][ + :, -1, : + ] # [B, D_codec] #TODO: check shape here # Get logits for the *last* token generated in this step - next_token_logits = self.codec_lm_head(last_token_hidden_state) # Use -1 index + next_token_logits = self.codec_lm_head( + last_token_hidden_state + ) # Use -1 index # suppress tokens between 4096:len(vocab)-3 # next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens? next_token_ids = topk_sampling( @@ -634,11 +734,14 @@ class SPEECH_LLM(nn.Module): if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id: break # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step - past_key_values = codec_outputs.past_key_values # Update KV cache - generated_speech_tokens_list.append(next_token_ids.squeeze(1).cpu().tolist()[0]) + past_key_values = codec_outputs.past_key_values # Update KV cache + generated_speech_tokens_list.append( + next_token_ids.squeeze(1).cpu().tolist()[0] + ) # --- 6. Return Results --- return generated_text_ids, generated_speech_tokens_list + def compute_accuracy(pad_outputs, pad_targets, ignore_label): """Calculate accuracy. Copied from https://github.com/X-LANCE/SLAM-LLM/blob/main/src/slam_llm/utils/metric.py @@ -717,4 +820,4 @@ def top_k_top_p_filtering( 1, sorted_indices, sorted_indices_to_remove ) logits[indices_to_remove] = filter_value - return logits \ No newline at end of file + return logits diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/speech_dataset.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/speech_dataset.py similarity index 99% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/speech_dataset.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/speech_dataset.py index d0a77fd0e..43a4efb5a 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/speech_dataset.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/speech_dataset.py @@ -1,13 +1,12 @@ from typing import Callable, Dict, List, Union import torch -from torch.utils.data.dataloader import DataLoader, default_collate - from lhotse import validate from lhotse.cut import CutSet from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures from lhotse.utils import compute_num_frames, ifnone from lhotse.workarounds import Hdf5MemoryIssueFix +from torch.utils.data.dataloader import DataLoader, default_collate class K2SpeechRecognitionDataset(torch.utils.data.Dataset): diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py similarity index 100% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/train.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py similarity index 60% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py index 3155174fb..e33d2437d 100644 --- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/web_demo.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/web_demo.py @@ 
-1,26 +1,25 @@ # Modified from https://github.com/QwenLM/Qwen2.5-Omni/blob/main/web_demo.py import io - -import numpy as np -import gradio as gr -import soundfile as sf - -import gradio.processing_utils as processing_utils -import tempfile -from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config -from gradio_client import utils as client_utils - -from argparse import ArgumentParser -import whisper -import torch -from peft import LoraConfig, get_peft_model -from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from model import SPEECH_LLM, EncoderProjector -from train import DEFAULT_SPEECH_TOKEN, add_model_arguments -import sherpa_onnx -from cosyvoice.cli.cosyvoice import CosyVoice import sys -sys.path.append('/workspace/CosyVoice/third_party/Matcha-TTS') +from argparse import ArgumentParser + +import gradio as gr +import gradio.processing_utils as processing_utils +import numpy as np +import sherpa_onnx +import soundfile as sf +import torch +import whisper +from cosyvoice.cli.cosyvoice import CosyVoice +from gradio_client import utils as client_utils +from model import SPEECH_LLM, EncoderProjector +from peft import LoraConfig, get_peft_model +from train import DEFAULT_SPEECH_TOKEN, add_model_arguments +from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + +# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") def get_model(params, device="cuda"): @@ -88,7 +87,7 @@ def get_model(params, device="cuda"): codec_lm = AutoModelForCausalLM.from_config( config=config, attn_implementation=attn_implementation, - torch_dtype=torch.float16 + torch_dtype=torch.float16, ) codec_lm.resize_token_embeddings(codec_vocab_size) codec_lm.vocab_size = codec_vocab_size @@ -102,12 +101,10 @@ def get_model(params, device="cuda"): llm, encoder_projector, codec_lm, - codec_lm_padding_side= "left" if params.use_flash_attn else "right", + codec_lm_padding_side="left" if params.use_flash_attn else "right", ) - checkpoint = torch.load( - f"{params.checkpoint_path}", map_location="cpu" - ) + checkpoint = torch.load(f"{params.checkpoint_path}", map_location="cpu") model.load_state_dict(checkpoint, strict=False) model.to(device) @@ -122,27 +119,37 @@ def audio_decode_cosyvoice(audio_tokens, codec_decoder): Args: audio_tokens (list): List of audio tokens to be processed. codec_decoder: Codec decoder for generating audio. - + Returns: torch.Tensor: Generated audio waveform. 
""" - flow_embedding = codec_decoder.frontend.spk2info['中文女']['embedding'] + flow_embedding = codec_decoder.frontend.spk2info["中文女"]["embedding"] flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32) prompt_speech_feat = torch.zeros(1, 0, 80) - tts_mel, _ = codec_decoder.model.flow.inference(token=audio_tokens.to(codec_decoder.model.device), - token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(codec_decoder.model.device), - prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device), - prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(codec_decoder.model.device), - prompt_feat=prompt_speech_feat.to(codec_decoder.model.device), - prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(codec_decoder.model.device), - embedding=flow_embedding.to(codec_decoder.model.device), - flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device),) + tts_mel, _ = codec_decoder.model.flow.inference( + token=audio_tokens.to(codec_decoder.model.device), + token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( + codec_decoder.model.device + ), + prompt_token=flow_prompt_speech_token.to(codec_decoder.model.device), + prompt_token_len=torch.tensor( + [flow_prompt_speech_token.shape[1]], dtype=torch.int32 + ).to(codec_decoder.model.device), + prompt_feat=prompt_speech_feat.to(codec_decoder.model.device), + prompt_feat_len=torch.tensor( + [prompt_speech_feat.shape[1]], dtype=torch.int32 + ).to(codec_decoder.model.device), + embedding=flow_embedding.to(codec_decoder.model.device), + flow_cache=torch.zeros(1, 80, 0, 2).to(codec_decoder.model.device), + ) - - audio_hat, _ = codec_decoder.model.hift.inference(speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)) + audio_hat, _ = codec_decoder.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) return audio_hat + def preprocess( messages, tokenizer, @@ -178,28 +185,14 @@ def preprocess( attention_mask = input_ids.ne(tokenizer.pad_token_id) return input_ids, attention_mask - - -def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): + +def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): def format_history(history: list): messages = [] for item in history: if isinstance(item["content"], str): - messages.append({"role": item['role'], "content": item['content']}) - # elif item["role"] == "user" and (isinstance(item["content"], list) or - # isinstance(item["content"], tuple)): - # file_path = item["content"][0] - # # TODO: check if the file_path's transcript is already in the history - # mime_type = client_utils.get_mimetype(file_path) - # if mime_type.startswith("audio"): - # messages.append({ - # "role": - # item['role'], - # "content": item["content"][1] # append audio transcript here - # }) - print('predict history: ', messages) - # messages = messages[-2:] # TODO: WAR: add history later + messages.append({"role": item["role"], "content": item["content"]}) return messages def decode( @@ -217,9 +210,8 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): dtype = torch.float32 device = model.llm.device - feature = feature.to(device, dtype=dtype)#.transpose(1, 2) - # assert feature.shape[2] == 80 - + feature = feature.to(device, dtype=dtype) + input_ids, attention_mask = preprocess([messages], tokenizer) generated_ids, audio_tokens = model.decode_with_speech_output( @@ -227,26 +219,21 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): ) hyps = 
tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - # print('hyps: ', hyps, 23333333333333333333333333) + yield {"type": "text", "data": hyps[0]} - # yield {"type": "text", "data": hyps} audio_tokens = [token for token in audio_tokens if token < 4096] audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0) audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model) - audio = audio_hat.squeeze(0).cpu().numpy() - # sf.write(f'{wav_name}.wav', audio_hat.squeeze(0).cpu().numpy(), 22050) + audio = audio_hat.squeeze(0).cpu().numpy() audio = np.array(audio * 32767).astype(np.int16) - # yield {"type": "audio", "data": (22050, audio)} - # with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile: - # sf.write(tmpfile.name, audio, 22050, format="WAV") - # audio_path = tmpfile.name wav_io = io.BytesIO() sf.write(wav_io, audio, samplerate=22050, format="WAV") wav_io.seek(0) wav_bytes = wav_io.getvalue() audio_path = processing_utils.save_bytes_to_cache( - wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE) + wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE + ) yield {"type": "audio", "data": audio_path} @@ -259,25 +246,27 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): gr.update(visible=True), # stop_btn ) print(2333, history, audio) - history.append({"role": "user", "content": (audio,)}) - history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}) + history.append({"role": "user", "content": (audio,)}) + history.append({"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"}) history.append({"role": "assistant", "content": ""}) - formatted_history = format_history(history=history) # only keep string text format + formatted_history = format_history( + history=history + ) # only keep string text format assert audio is not None audio_transcript = get_transcript( audio, asr_model, ) - print('audio_transcript: ', audio_transcript) history[-2]["content"] = audio_transcript fbank = whisper.log_mel_spectrogram(audio, device=model.llm.device) fbank = fbank.unsqueeze(0) assert fbank.ndim == 3 - # history.append({"role": "assistant", "content": ""}) - for chunk in decode(model, token2wav_model, tokenizer, fbank, formatted_history): + for chunk in decode( + model, token2wav_model, tokenizer, fbank, formatted_history + ): if chunk["type"] == "text": history[-1]["content"] = chunk["data"] yield ( @@ -287,10 +276,9 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): gr.update(visible=True), # stop_btn ) if chunk["type"] == "audio": - history.append({ - "role": "assistant", - "content": gr.Audio(chunk["data"]) - }) + history.append( + {"role": "assistant", "content": gr.Audio(chunk["data"])} + ) # Final yield yield ( @@ -304,8 +292,7 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): with gr.Tab("Online"): with gr.Row(): with gr.Column(scale=1): - microphone = gr.Audio(sources=['microphone'], - type="filepath") + microphone = gr.Audio(sources=["microphone"], type="filepath") submit_btn = gr.Button("Submit", variant="primary") stop_btn = gr.Button("Stop", visible=False) clear_btn = gr.Button("Clear History") @@ -315,64 +302,80 @@ def _launch_demo(args, model, tokenizer, token2wav_model, asr_model): def clear_history(): return [], gr.update(value=None) - submit_event = submit_btn.click(fn=media_predict, - inputs=[ - microphone, - media_chatbot, - ], - outputs=[ - microphone, - media_chatbot, submit_btn, - stop_btn - ]) + submit_event = submit_btn.click( + fn=media_predict, + inputs=[ + 
microphone, + media_chatbot, + ], + outputs=[microphone, media_chatbot, submit_btn, stop_btn], + ) stop_btn.click( - fn=lambda: - (gr.update(visible=True), gr.update(visible=False)), + fn=lambda: (gr.update(visible=True), gr.update(visible=False)), inputs=None, outputs=[submit_btn, stop_btn], cancels=[submit_event], - queue=False) - clear_btn.click(fn=clear_history, - inputs=None, - outputs=[media_chatbot, microphone]) + queue=False, + ) + clear_btn.click( + fn=clear_history, inputs=None, outputs=[media_chatbot, microphone] + ) - demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100, - ssr_mode=False, - share=args.share, - inbrowser=args.inbrowser, - server_port=args.server_port, - server_name=args.server_name,) + demo.queue(default_concurrency_limit=100, max_size=100).launch( + max_threads=100, + ssr_mode=False, + share=args.share, + inbrowser=args.inbrowser, + server_port=args.server_port, + server_name=args.server_name, + ) def _get_args(): parser = ArgumentParser() - parser.add_argument('--checkpoint-path', - type=str, - default=None, - help='Checkpoint name or path, default to %(default)r') - parser.add_argument('--token2wav-path', - type=str, - default=None, - help='Token2Wav path, default to %(default)r') - parser.add_argument('--asr-model-dir', - type=str, - default=None, - help='ASR model dir, default to %(default)r') - parser.add_argument('--flash-attn2', - action='store_true', - default=False, - help='Enable flash_attention_2 when loading the model.') - parser.add_argument('--share', - action='store_true', - default=False, - help='Create a publicly shareable link for the interface.') - parser.add_argument('--inbrowser', - action='store_true', - default=False, - help='Automatically launch the interface in a new tab on the default browser.') - parser.add_argument('--server-port', type=int, default=8001, help='Demo server port.') - parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.') + parser.add_argument( + "--checkpoint-path", + type=str, + default=None, + help="Checkpoint name or path, default to %(default)r", + ) + parser.add_argument( + "--token2wav-path", + type=str, + default=None, + help="Token2Wav path, default to %(default)r", + ) + parser.add_argument( + "--asr-model-dir", + type=str, + default=None, + help="ASR model dir, default to %(default)r", + ) + parser.add_argument( + "--flash-attn2", + action="store_true", + default=False, + help="Enable flash_attention_2 when loading the model.", + ) + parser.add_argument( + "--share", + action="store_true", + default=False, + help="Create a publicly shareable link for the interface.", + ) + parser.add_argument( + "--inbrowser", + action="store_true", + default=False, + help="Automatically launch the interface in a new tab on the default browser.", + ) + parser.add_argument( + "--server-port", type=int, default=8001, help="Demo server port." + ) + parser.add_argument( + "--server-name", type=str, default="127.0.0.1", help="Demo server name." 
+ ) add_model_arguments(parser) args = parser.parse_args() return args @@ -401,6 +404,7 @@ def read_wave(wave_filename: str): return samples_float32, sample_rate + def get_transcript(audio_path, recognizer): samples, sample_rate = read_wave(audio_path) s = recognizer.create_stream() @@ -408,10 +412,13 @@ def get_transcript(audio_path, recognizer): recognizer.decode_streams([s]) return s.result.text + if __name__ == "__main__": args = _get_args() model, tokenizer = get_model(args) - token2wav = CosyVoice(args.token2wav_path, load_jit=False, load_trt=False, fp16=False) + token2wav = CosyVoice( + args.token2wav_path, load_jit=False, load_trt=False, fp16=False + ) asr_model = sherpa_onnx.OfflineRecognizer.from_paraformer( paraformer=f"{args.asr_model_dir}/model.int8.onnx", @@ -423,4 +430,4 @@ if __name__ == "__main__": debug=False, ) - _launch_demo(args, model, tokenizer, token2wav, asr_model) \ No newline at end of file + _launch_demo(args, model, tokenizer, token2wav, asr_model) diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/whisper_encoder_forward_monkey_patch.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py similarity index 100% rename from egs/speech_llm/SPEECH2SPEECH/slam_omni/whisper_encoder_forward_monkey_patch.py rename to egs/speech_llm/SPEECH2SPEECH/qwen_omni/whisper_encoder_forward_monkey_patch.py