diff --git a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
index 54a8983df..5cda487e3 100755
--- a/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
+++ b/egs/speech_llm/SPEECH2SPEECH/slam_omni/decode.py
@@ -60,7 +60,7 @@ from data_module import AsrDataModule
 from lhotse.cut import Cut
 from model import SPEECH_LLM, EncoderProjector
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
 from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
@@ -70,10 +70,164 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    str2bool,
     write_error_stats,
+    average_checkpoints,
 )
 
 
+def get_model(params, device):
+    """Load and prepare the speech-to-speech model."""
+    if params.remove_whisper_encoder_input_length_restriction:
+        replace_whisper_encoder_forward()
+
+    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
+    speech_encoder = whisper_model.encoder
+    speech_encoder_dim = whisper_model.dims.n_audio_state
+    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
+
+    if params.use_flash_attn:
+        attn_implementation = "flash_attention_2"
+        # torch_dtype=torch.bfloat16 FIX ME
+        torch_dtype = torch.float16
+        tokenizer.padding_side = "left"
+
+    else:
+        attn_implementation = "eager"
+        torch_dtype = torch.float16
+        tokenizer.padding_side = "right"
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        params.llm_path_or_name,
+        attn_implementation=attn_implementation,
+        torch_dtype=torch_dtype,
+    )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "up_proj",
+                "gate_proj",
+                "down_proj",
+            ],
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()
+
+    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
+    tokenizer.add_special_tokens(special_tokens_dict)
+    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
+    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
+        DEFAULT_SPEECH_TOKEN
+    )
+
+    encoder_projector = EncoderProjector(
+        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
+    )
+
+    if params.enable_speech_output:
+        # Determine attn_implementation and torch_dtype based on use_flash_attn
+        if params.use_flash_attn:
+            attn_implementation = "flash_attention_2"
+            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
+        else:
+            attn_implementation = "eager"
+            torch_dtype = torch.float16
+
+        # codec_lm = AutoModelForCausalLM.from_pretrained(
+        #     params.llm_path_or_name,
+        #     attn_implementation=attn_implementation,
+        #     torch_dtype=torch_dtype,
+        # )
+        codec_vocab_size = 4096 + 4
+        config = Qwen2Config(
+            vocab_size=codec_vocab_size,
+            hidden_size=1024,
+            num_hidden_layers=12,
+            num_attention_heads=16,
+            num_key_value_heads=16,
+            intermediate_size=2048,
+            max_position_embeddings=4096,
+        )
+        # codec_lm = Qwen2ForCausalLM(config=config)
+        # Pass attn_implementation and torch_dtype to the constructor
+        # Use AutoModelForCausalLM.from_config for more generality
+        codec_lm = AutoModelForCausalLM.from_config(
+            config=config,
+            attn_implementation=attn_implementation,
+            torch_dtype=torch_dtype
+        )
+        # cosyvoice2_token_size = 6561
+        codec_lm.resize_token_embeddings(codec_vocab_size)
+        codec_lm.vocab_size = codec_vocab_size
+        codec_lm.config.pad_token_id = codec_vocab_size - 1
+        codec_lm.config.eos_token_id = codec_vocab_size - 2
+        codec_lm.config.bos_token_id = codec_vocab_size - 3
+        codec_lm.config.mask_token_id = codec_vocab_size - 4
+        # if params.use_lora:
+        #     lora_config = LoraConfig(
+        #         r=64,
+        #         lora_alpha=16,
+        #         target_modules=[
+        #             "q_proj",
+        #             "k_proj",
+        #             "v_proj",
+        #             "o_proj",
+        #             "up_proj",
+        #             "gate_proj",
+        #             "down_proj",
+        #         ],
+        #         lora_dropout=0.05,
+        #         task_type="CAUSAL_LM",
+        #     )
+        #     codec_lm = get_peft_model(codec_lm, lora_config)
+        #     codec_lm.print_trainable_parameters()
+    else:
+        codec_lm = None
+
+    model = SPEECH_LLM(
+        speech_encoder,
+        llm,
+        encoder_projector,
+        codec_lm,
+        codec_lm_padding_side="left" if params.use_flash_attn else "right",
+    )
+
+    if params.avg > 1:
+        start = params.epoch - params.avg + 1
+        assert start >= 1, start
+        checkpoint = torch.load(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+        )
+        assert "model" not in checkpoint
+        # deepspeed converted checkpoint only contains model state_dict
+        filenames = [
+            f"{params.exp_dir}/epoch-{epoch}.pt"
+            for epoch in range(start, params.epoch + 1)
+        ]
+        avg_checkpoint = average_checkpoints(filenames)
+        model.load_state_dict(avg_checkpoint, strict=False)
+
+        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+        torch.save(avg_checkpoint, filename)
+    else:
+        checkpoint = torch.load(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+        )
+        model.load_state_dict(checkpoint, strict=False)
+
+    model.to(device)
+    model.eval()
+    return model, tokenizer
+
+
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
@@ -171,13 +325,6 @@ def get_parser():
         help="The experiment dir",
     )
 
-    parser.add_argument(
-        "--remove-whisper-encoder-input-length-restriction",
-        type=str2bool,
-        default=True,
-        help="replace whisper encoder forward method to remove input length restriction",
-    )
-
     # parser.add_argument(
     #     "--dataset",
     #     type=str,
@@ -321,7 +468,7 @@ def decode_one_batch(
         with open(speech_token_file_name, 'w') as f:
             # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
             # torchaudio.save(save_path, speech_output.cpu(), 16000)
-            print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
+            # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
             save_str = " ".join([str(i) for i in generated_speech_output])
             f.write(f"{cut_id}|{save_str}\n")
 
@@ -509,155 +656,8 @@ def main():
 
     logging.info(f"device: {device}")
 
-    if params.remove_whisper_encoder_input_length_restriction:
-        replace_whisper_encoder_forward()
+    model, tokenizer = get_model(params, device)
 
-    whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
-    speech_encoder = whisper_model.encoder
-    speech_encoder_dim = whisper_model.dims.n_audio_state
-    tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
-
-    if params.use_flash_attn:
-        attn_implementation = "flash_attention_2"
-        # torch_dtype=torch.bfloat16 FIX ME
-        torch_dtype = torch.float16
-        tokenizer.padding_side = "left"
-
-    else:
-        attn_implementation = "eager"
-        torch_dtype = torch.float16
-        tokenizer.padding_side = "right"
-
-    llm = AutoModelForCausalLM.from_pretrained(
-        params.llm_path_or_name,
-        attn_implementation=attn_implementation,
-        torch_dtype=torch_dtype,
-    )
-    if params.use_lora:
-        lora_config = LoraConfig(
-            r=64,
-            lora_alpha=16,
-            target_modules=[
-                "q_proj",
-                "k_proj",
-                "v_proj",
-                "o_proj",
-                "up_proj",
-                "gate_proj",
-                "down_proj",
-            ],
-            task_type="CAUSAL_LM",
-        )
-        llm = get_peft_model(llm, lora_config)
-        llm.print_trainable_parameters()
-
-    special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
-    tokenizer.add_special_tokens(special_tokens_dict)
-    llm.config.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
-    llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
-    llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-
-    llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
-        DEFAULT_SPEECH_TOKEN
-    )
-
-    encoder_projector = EncoderProjector(
-        speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
-    )
-
-    if params.enable_speech_output:
-        # Determine attn_implementation and torch_dtype based on use_flash_attn
-        if params.use_flash_attn:
-            attn_implementation = "flash_attention_2"
-            torch_dtype = torch.float16  # Or torch.bfloat16 if needed/supported
-        else:
-            attn_implementation = "eager"
-            torch_dtype = torch.float16
-
-        # codec_lm = AutoModelForCausalLM.from_pretrained(
-        #     params.llm_path_or_name,
-        #     attn_implementation=attn_implementation,
-        #     torch_dtype=torch_dtype,
-        # )
-        codec_vocab_size = 4096 + 4
-        config = Qwen2Config(
-            vocab_size=codec_vocab_size,
-            hidden_size=1024,
-            num_hidden_layers=12,
-            num_attention_heads=16,
-            num_key_value_heads=16,
-            intermediate_size=2048,
-            max_position_embeddings=4096,
-        )
-        # codec_lm = Qwen2ForCausalLM(config=config)
-        # Pass attn_implementation and torch_dtype to the constructor
-        # Use AutoModelForCausalLM.from_config for more generality
-        codec_lm = AutoModelForCausalLM.from_config(
-            config=config,
-            attn_implementation=attn_implementation,
-            torch_dtype=torch_dtype
-        )
-        # cosyvoice2_token_size = 6561
-        codec_lm.resize_token_embeddings(codec_vocab_size)
-        codec_lm.vocab_size = codec_vocab_size
-        codec_lm.config.pad_token_id = codec_vocab_size - 1
-        codec_lm.config.eos_token_id = codec_vocab_size - 2
-        codec_lm.config.bos_token_id = codec_vocab_size - 3
-        codec_lm.config.mask_token_id = codec_vocab_size - 4
-        # if params.use_lora:
-        #     lora_config = LoraConfig(
-        #         r=64,
-        #         lora_alpha=16,
-        #         target_modules=[
-        #             "q_proj",
-        #             "k_proj",
-        #             "v_proj",
-        #             "o_proj",
-        #             "up_proj",
-        #             "gate_proj",
-        #             "down_proj",
-        #         ],
-        #         lora_dropout=0.05,
-        #         task_type="CAUSAL_LM",
-        #     )
-        #     codec_lm = get_peft_model(codec_lm, lora_config)
-        #     codec_lm.print_trainable_parameters()
-    else:
-        codec_lm = None
-
-    model = SPEECH_LLM(
-        speech_encoder,
-        llm,
-        encoder_projector,
-        codec_lm,
-        codec_lm_padding_side="left" if params.use_flash_attn else "right",
-    )
-
-    if params.avg > 1:
-        start = params.epoch - params.avg + 1
-        assert start >= 1, start
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        assert "model" not in checkpoint
-        # deepspeed converted checkpoint only contains model state_dict
-        filenames = [
-            f"{params.exp_dir}/epoch-{epoch}.pt"
-            for epoch in range(start, params.epoch + 1)
-        ]
-        avg_checkpoint = average_checkpoints(filenames)
-        model.load_state_dict(avg_checkpoint, strict=False)
-
-        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-        torch.save(avg_checkpoint, filename)
-    else:
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        model.load_state_dict(checkpoint, strict=False)
-
-    model.to(device)
-    model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
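
Note (not part of the patch): the params.avg > 1 branch of the new get_model delegates the actual tensor averaging to the file-local average_checkpoints helper, whose body lies outside these hunks. The sketch below only illustrates the assumed semantics, a uniform average of per-epoch state_dicts saved as plain dicts (which is what the assert "model" not in checkpoint check implies); the name average_checkpoints_sketch is hypothetical and not part of decode.py.

# Illustrative sketch only; the real helper in decode.py may differ in details.
from pathlib import Path
from typing import List, Union

import torch


def average_checkpoints_sketch(filenames: List[Union[str, Path]]) -> dict:
    """Uniformly average the tensors of several plain model state_dicts."""
    n = len(filenames)
    # Start from the first checkpoint and accumulate the rest into it.
    avg = torch.load(filenames[0], map_location="cpu")
    for filename in filenames[1:]:
        state = torch.load(filename, map_location="cpu")
        for k in avg:
            avg[k] += state[k]
    # Divide by the number of checkpoints; integer buffers use floor division.
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] = torch.div(avg[k], n, rounding_mode="floor")
    return avg

For example, with params.epoch = 10 and params.avg = 3, get_model averages epoch-8.pt through epoch-10.pt and caches the result as epoch-10-avg-3.pt under params.exp_dir.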