diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
index acdfb4f2c..a52f84b0c 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py
@@ -413,6 +413,8 @@ class AsrDataModule:
         ultrachat_cuts = load_manifest_lazy(
             self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
         )
+        VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
+        ultrachat_cuts = ultrachat_cuts.resample(16000)
         return CutSet.mux(
             VoiceAssistant_cuts,
             ultrachat_cuts,
@@ -427,6 +429,7 @@ class AsrDataModule:
         VoiceAssistant_cuts = load_manifest_lazy(
             self.args.manifest_dir / "cuts_voice_assistant.00000.jsonl.gz"
         )
+        VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
         return VoiceAssistant_cuts
 
     @lru_cache()
@@ -435,6 +438,7 @@ class AsrDataModule:
         VoiceAssistant_cuts = load_manifest_lazy(
             self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
         )
+        VoiceAssistant_cuts = VoiceAssistant_cuts.resample(16000)
         return {"test": VoiceAssistant_cuts}
 
     @lru_cache()
@@ -482,36 +486,36 @@ class AsrDataModule:
 
         librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_100,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         librispeech_other_cuts = CutSet.from_huggingface_dataset(
             librispeech_other,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_360,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         gigaspeech_cuts = CutSet.from_huggingface_dataset(
-            gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
+            gigaspeech, audio_key="audio", text_key="text"
         )
 
         people_speech_clean_cuts = CutSet.from_huggingface_dataset(
             people_speech_clean,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         people_speech_dirty_sa_cuts = CutSet.from_huggingface_dataset(
             people_speech_dirty_sa,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         return CutSet.mux(
@@ -540,8 +544,8 @@ class AsrDataModule:
         )
         librispeech_clean_valid_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_valid,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         return librispeech_clean_valid_cuts
@@ -567,20 +571,20 @@ class AsrDataModule:
 
         librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_100,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
        )
 
         librispeech_other_cuts = CutSet.from_huggingface_dataset(
             librispeech_other,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
             librispeech_clean_360,
-            audio_key=self.args.audio_key,
-            text_key=self.args.text_key,
+            audio_key="audio",
+            text_key="text",
         )
 
         return CutSet.mux(
@@ -603,7 +607,148 @@ class AsrDataModule:
         )
 
         gigaspeech_cuts = CutSet.from_huggingface_dataset(
-            gigaspeech, audio_key=self.args.audio_key, text_key=self.args.text_key
+            gigaspeech, audio_key="audio", text_key="text"
         )
 
         return gigaspeech_cuts
+
+    @lru_cache()
+    def train_cuts_instruct_s2s(self) -> CutSet:
+        logging.info("About to get train cuts")
+        if self.args.huggingface_dataset_path_or_name is not None:
+            data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+        else:
+            data_path = "yuekai/InstructS2S-200K"
+        # 148_688
+        instruct_s2s_train = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+            instruct_s2s_train,
+            audio_key="question_audio",
+            text_key="answer",
+        )
+
+        instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+        return instruct_s2s_train_cuts
+
+    @lru_cache()
+    def train_cuts_en_speech2speech(self) -> CutSet:
+        logging.info("About to get train cuts")
+        VoiceAssistant_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+        )
+        ultrachat_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+        )
+
+        if self.args.huggingface_dataset_path_or_name is not None:
+            data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+        else:
+            data_path = "yuekai/InstructS2S-200K"
+        # 148_688
+        instruct_s2s_train = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+            instruct_s2s_train,
+            audio_key="question_audio",
+            text_key="answer",
+        )
+
+        instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+
+        return CutSet.mux(
+            VoiceAssistant_cuts,
+            ultrachat_cuts,
+            instruct_s2s_train_cuts,
+            weights=[
+                len(VoiceAssistant_cuts),
+                len(ultrachat_cuts),
+                423_000,
+            ],
+        )
+
+    @lru_cache()
+    def train_cuts_en_speech2speech_librispeech(self) -> CutSet:
+        logging.info("About to get train cuts")
+        VoiceAssistant_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_voice_assistant_00001-00049.jsonl.gz"
+        )
+        ultrachat_cuts = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_ultrachat_train.jsonl.gz"
+        )
+
+        if self.args.huggingface_dataset_path_or_name is not None:
+            data_path = self.args.huggingface_dataset_path_or_name + "/InstructS2S-200K"
+        else:
+            data_path = "yuekai/InstructS2S-200K"
+        # 148_688
+        instruct_s2s_train = load_dataset(
+            data_path, split="train", streaming=True
+        )
+
+        instruct_s2s_train_cuts = CutSet.from_huggingface_dataset(
+            instruct_s2s_train,
+            audio_key="question_audio",
+            text_key="answer",
+        )
+
+        instruct_s2s_train_cuts = instruct_s2s_train_cuts.resample(16000)
+
+        if self.args.huggingface_dataset_path_or_name is not None:
+            librispeech_path = self.args.huggingface_dataset_path_or_name + "/librispeech_asr"
+        else:
+            librispeech_path = "fixie-ai/librispeech_asr"
+        # 148_688
+        librispeech_other = load_dataset(
+            librispeech_path, "other", split="train.500", streaming=True
+        )
+        # 104_014
+        librispeech_clean_360 = load_dataset(
+            librispeech_path, "clean", split="train.360", streaming=True
+        )
+        # 28_539
+        librispeech_clean_100 = load_dataset(
+            librispeech_path, "clean", split="train.100", streaming=True
+        )
+
+        librispeech_clean_100_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_100,
+            audio_key="audio",
+            text_key="text",
+        )
+
+        librispeech_other_cuts = CutSet.from_huggingface_dataset(
+            librispeech_other,
+            audio_key="audio",
+            text_key="text",
+        )
+
+        librispeech_clean_360_cuts = CutSet.from_huggingface_dataset(
+            librispeech_clean_360,
+            audio_key="audio",
+            text_key="text",
+        )
+
+
+        return CutSet.mux(
+            librispeech_other_cuts,
+            VoiceAssistant_cuts,
+            ultrachat_cuts,
+            librispeech_clean_360_cuts,
+            instruct_s2s_train_cuts,
+            librispeech_clean_100_cuts,
+            weights=[
+                148688,
+                len(VoiceAssistant_cuts),
+                len(ultrachat_cuts),
+                104014,
+                423_000,
+                28539,
+            ],
+        )
\ No newline at end of file
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
index e65cc7829..9554d85e4 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -193,6 +193,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--last-stage-model-path",
+        type=str,
+        default=None,
+        help="""The path to the model from the last training stage. If not None, training will start from this model.
+        """,
+    )
     parser.add_argument(
         "--sampler-state-dict-path",
         type=str,
@@ -229,13 +236,6 @@ def get_parser():
         help="Whether to unfreeze speech adaptor during training.",
     )
 
-    parser.add_argument(
-        "--prompt-template",
-        type=str,
-        default="speech_qa",
-        help="The prompt template to use.",
-    )
-
    parser.add_argument(
         "--dataset",
         type=str,
@@ -300,7 +300,6 @@ def get_params() -> AttributeDict:
 
 def extract_text_and_speech_token(
     batch: dict,
-    prompt_template: str,
     enable_speech_output: bool
 ) -> Tuple[List[Dict[str, str]], Optional[List[Any]]]:
     """
@@ -325,54 +324,54 @@ def extract_text_and_speech_token(
     answers = batch["supervisions"]["text"]
     batch_size = len(answers)
 
-    if prompt_template == "speech_qa":
-        for i in range(batch_size):
-            message_list_item = []
-            if 'round' in batch["supervisions"]["cut"][i].custom:
-                # slam_omni format dataset
-                current_question_with_history = batch["supervisions"]["cut"][i].custom["question"]
-                total_round = batch["supervisions"]["cut"][i].custom["round"]
-                history_context = current_question_with_history.rsplit(":", 1)[0].strip()
-                if total_round > 1:
-                    history_question_answer = history_context.split("USER:")
-                    history_question_answer = [item for item in history_question_answer if item]
-                    for j in range(total_round - 1):
-                        question_answer = history_question_answer[j].split("ASSISTANT:")
-                        message_list_item += [
-                            {"role": "user", "content": question_answer[0].strip()},
-                            {"role": "assistant", "content": question_answer[1].strip()},
-                        ]
-            message_list_item += [
-                {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-                {"role": "assistant", "content": answers[i]},
-            ]
-            messages.append(message_list_item)
+    prompt_template_dict = {
+        "speech_qa": f"{DEFAULT_SPEECH_TOKEN}",
+        "speech_continuation": f"Continue the following text using less than 50 words:\\n\\n{DEFAULT_SPEECH_TOKEN}",
+        "asr": f"Transcribe the following audio into text:\\n\\n{DEFAULT_SPEECH_TOKEN}",
+    }
 
-    elif prompt_template == "speech_continuation":
-        # speech_tokens remains None
-        for i in range(batch_size):
-            message_list_item = [
-                {
-                    "role": "user",
-                    "content": f"Continue the following text using less than 50 words:\\n\\n{DEFAULT_SPEECH_TOKEN}",
-                },
-                {"role": "assistant", "content": answers[i]},
-            ]
-            messages.append(message_list_item)
+    for i in range(batch_size):
+        # The prompt template is chosen per cut from its custom fields; "speech_qa" is the default.
+        current_prompt_template = "speech_qa"  # Default value for prompt_template for the current item
+        target = answers[i]
+        message_list_item = []
+
+        custom_data = batch["supervisions"]["cut"][i].custom
 
-    elif prompt_template == "asr":
-        # speech_tokens remains None
-        for i in range(batch_size):
-            message_list_item = [
-                {
-                    "role": "user",
-                    "content": f"Transcribe the following audio into text:\\n\\n{DEFAULT_SPEECH_TOKEN}",
-                },
-                {"role": "assistant", "content": answers[i]},
-            ]
-            messages.append(message_list_item)
-    else:
-        raise ValueError(f"Unknown prompt template: {prompt_template}")
+        if 'round' in custom_data:
+            # slam_omni format dataset
+            # For 'round' type, the current interaction's user prompt will use current_prompt_template ("speech_qa")
+            current_question_with_history = custom_data["question"]
+            total_round = custom_data["round"]
+            history_context = current_question_with_history.rsplit(":", 1)[0].strip()
+            if total_round > 1:
+                history_question_answer = history_context.split("USER:")
+                history_question_answer = [item for item in history_question_answer if item]
+                for j in range(total_round - 1):
+                    question_answer = history_question_answer[j].split("ASSISTANT:")
+                    message_list_item += [
+                        {"role": "user", "content": question_answer[0].strip()},
+                        {"role": "assistant", "content": question_answer[1].strip()},
+                    ]
+        elif 'continuation' in custom_data:
+            # see https://huggingface.co/datasets/fixie-ai/librispeech_asr
+            ASR_PROBABILITY = 0.3
+            if random.random() < ASR_PROBABILITY:
+                current_prompt_template = "asr"
+            else:
+                current_prompt_template = "speech_continuation"
+                target = custom_data["continuation"]
+        else:
+            # single-round, speech2speech conversation data
+            pass
+        message_list_item += [
+            {"role": "user", "content": prompt_template_dict[current_prompt_template]},
+            {"role": "assistant", "content": target},
+        ]
+        messages.append(message_list_item)
 
     return messages, speech_tokens
 
@@ -428,14 +427,17 @@ def preprocess(
 
 def process_batch_text_continuation(batch: dict):
     messages = []
-    for i in range(len(batch["supervisions"]["text"])):
-        transcript = batch["supervisions"]["cut"][i].custom["text"]
+    transcripts = batch["supervisions"]["text"]
+    continuations = [
+        cut.custom["continuation"] for cut in batch["supervisions"]["cut"]
+    ]
+    for i in range(len(transcripts)):
         message = [
             {
                 "role": "user",
-                "content": f"Continue the following text using less than 50 words:\n\n{transcript}{DEFAULT_SPEECH_TOKEN}",
+                "content": f"Continue the following text using less than 50 words:\n\n{transcripts[i]}{DEFAULT_SPEECH_TOKEN}",
             },
-            {"role": "assistant", "content": batch["supervisions"]["text"][i]},
+            {"role": "assistant", "content": continuations[i]},
         ]
         messages.append(message)
     return messages
@@ -532,7 +534,7 @@ def compute_loss(
 
     # WAR: TODO FIXME merge process_batch_slam_omni and process_batch_vocalnet
     messages, answer_cosyvoice_speech_token = extract_text_and_speech_token(
-        batch, params.prompt_template, params.enable_speech_output
+        batch, params.enable_speech_output
    )
 
     input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
@@ -550,7 +552,6 @@ def compute_loss(
             labels=target_ids.to(device),
         )
     elif params.loss_type == "kl_div":
-        assert params.prompt_template == "speech_continuation"
         messages_text = process_batch_text_continuation(batch)
         (
             teacher_input_ids,
@@ -942,15 +943,18 @@ def run(rank, world_size, args):
         teacher_llm=teacher_llm,
     )
 
-    if params.pretrained_model_path:
-        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
-        # set params.batch_idx_train according to the checkpoint name
-        if "checkpoint-" in params.pretrained_model_path:
-            params.batch_idx_train = int(
-                params.pretrained_model_path.split("-")[-1].split("/")[0]
-            )
-
+    if params.pretrained_model_path or params.last_stage_model_path:
+        if params.pretrained_model_path is None:
+            checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        else:
+            checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+            # set params.batch_idx_train according to the checkpoint name
+            if "checkpoint-" in params.pretrained_model_path:
+                params.batch_idx_train = int(
+                    params.pretrained_model_path.split("-")[-1].split("/")[0]
+                )
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -999,6 +1003,12 @@ def run(rank, world_size, args):
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}, lenth: {codec_len}"
             )
             return False
+        if "question" in c.custom:
+            if len(c.custom["question"]) > 1200:
+                # logging.warning(
+                #     f"Exclude cut with ID {c.id} from training. question length: {len(c.custom['question'])}"
+                # )
+                return False
         return True
 
     if params.dataset == "slam_omni_belle":
@@ -1007,6 +1017,12 @@ def run(rank, world_size, args):
     elif params.dataset == "vocalnet_ultrachat_voiceassistant":
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
+    elif params.dataset == "vocalnet_ultrachat_voiceassistant_instruct_s2s":
+        train_cuts = data_module.train_cuts_en_speech2speech()
+        valid_cuts = data_module.valid_cuts_en_vocalnet()
+    elif params.dataset == "vocalnet_ultrachat_voiceassistant_instruct_s2s_librispeech":
+        train_cuts = data_module.train_cuts_en_speech2speech_librispeech()
+        valid_cuts = data_module.valid_cuts_en_vocalnet()
     elif params.dataset == "ultravox_multi_en":
         train_cuts = data_module.train_cuts_ultravox()
         valid_cuts = data_module.valid_cuts_ultravox()