diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
index f1b25d3e6..5b5628f74 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -68,7 +68,6 @@ from transformers import (
     Qwen2Config,
     Qwen2ForCausalLM,
 )
-
 from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -306,8 +305,7 @@ def get_params() -> AttributeDict:


 def extract_text_and_speech_token(
-    batch: dict,
-    enable_speech_output: bool
+    batch: dict, enable_speech_output: bool
 ) -> Tuple[List[Dict[str, str]], Optional[List[Any]]]:
     """
     Extracts messages and speech tokens from a batch based on the dataset format.
@@ -342,28 +340,34 @@ def extract_text_and_speech_token(
         # The 'prompt_template' argument to the function seems unused if we determine it here.
         # For now, I will proceed assuming the internal logic dictates the template.
         # If the function argument `prompt_template` was meant to be the default, this logic would need adjustment.
-        current_prompt_template = "speech_qa" # Default value for prompt_template for the current item
+        current_prompt_template = (
+            "speech_qa"  # Default value for prompt_template for the current item
+        )
         target = answers[i]
-        message_list_item = []
-
+        message_list_item = []
+
         custom_data = batch["supervisions"]["cut"][i].custom
-        if 'round' in custom_data:
+        if "round" in custom_data:
             # slam_omni format dataset
             # For 'round' type, the current interaction's user prompt will use current_prompt_template ("speech_qa")
             current_question_with_history = custom_data["question"]
             total_round = custom_data["round"]
-            history_context = current_question_with_history.rsplit(":", 1)[0].strip()
+            history_context = current_question_with_history.rsplit(":", 1)[
+                0
+            ].strip()
             if total_round > 1:
                 history_question_answer = history_context.split("USER:")
-                history_question_answer = [item for item in history_question_answer if item]
+                history_question_answer = [
+                    item for item in history_question_answer if item
+                ]
                 for j in range(total_round - 1):
                     question_answer = history_question_answer[j].split("ASSISTANT:")
                     message_list_item += [
                         {"role": "user", "content": question_answer[0].strip()},
                         {"role": "assistant", "content": question_answer[1].strip()},
                     ]
-        elif 'continuation' in custom_data:
+        elif "continuation" in custom_data:
             # see https://huggingface.co/datasets/fixie-ai/librispeech_asr
             ASR_PROBABILITY = 0.3
             if random.random() < ASR_PROBABILITY:
@@ -382,6 +386,7 @@ def extract_text_and_speech_token(

     return messages, speech_tokens

+
 def preprocess(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -432,12 +437,11 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids

+
 def process_batch_text_continuation(batch: dict):
     messages = []
     transcripts = batch["supervisions"]["text"]
-    continuations = [
-        cut.custom["continuation"] for cut in batch["supervisions"]["cut"]
-    ]
+    continuations = [cut.custom["continuation"] for cut in batch["supervisions"]["cut"]]
     for i in range(len(transcripts)):
         message = [
             {
@@ -449,6 +453,7 @@ def preprocess(
         messages.append(message)
     return messages

+
 def preprocess_teacher(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -827,27 +832,29 @@ def get_model(params):
     if not params.unfreeze_llm:
         for name, param in llm.named_parameters():
             param.requires_grad = False
-    else:
-        if params.use_lora:
-            lora_config = LoraConfig(
-                r=64,
-                lora_alpha=16,
-                target_modules=[
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "gate_proj",
-                    "down_proj",
-                ],
-                lora_dropout=0.05,
-                task_type="CAUSAL_LM",
-            )
-            llm = get_peft_model(llm, lora_config)
-            llm.print_trainable_parameters()
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "up_proj",
+                "gate_proj",
+                "down_proj",
+            ],
+            lora_dropout=0.05,
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()

     llm.config.pad_token_id = tokenizer.pad_token_id
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
         DEFAULT_SPEECH_TOKEN
     )
@@ -884,7 +891,9 @@ def get_model(params):
         elif params.speech_tokenizer_type == "cosyvoice1":
             codec_vocab_size = 4096 + 4
         else:
-            raise ValueError(f"Unknown speech tokenizer type: {params.speech_tokenizer_type}")
+            raise ValueError(
+                f"Unknown speech tokenizer type: {params.speech_tokenizer_type}"
+            )

         config = Qwen2Config(
             vocab_size=codec_vocab_size,
@@ -921,10 +930,14 @@ def get_model(params):
     if params.pretrained_model_path or params.last_stage_model_path:
         if params.pretrained_model_path is None:
             checkpoint = torch.load(params.last_stage_model_path, map_location="cpu")
-            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+            missing_keys, unexpected_keys = model.load_state_dict(
+                checkpoint, strict=False
+            )
         else:
             checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-            missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+            missing_keys, unexpected_keys = model.load_state_dict(
+                checkpoint, strict=False
+            )
             # set params.batch_idx_train according to the checkpoint name
             if "checkpoint-" in params.pretrained_model_path:
                 params.batch_idx_train = int(
@@ -940,6 +953,7 @@ def get_model(params):

     return model, tokenizer

+
 def run(rank, world_size, args):
     """
     Args: