diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 0815b6d3a..ddaf6078c 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -126,6 +126,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to use lora to fine-tune llm.",
     )
 
+    parser.add_argument(
+        "--unfreeze-llm",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze llm during training.",
+    )
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -587,30 +594,30 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-
-            model.save_checkpoint(
-                save_dir=params.exp_dir,
-                tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                client_state={},
-                exclude_frozen_parameters=True
-            )
-
-            if rank == 0:
-                convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir,
-                    f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
+            if batch_idx != 0:
+                model.save_checkpoint(
+                    save_dir=params.exp_dir,
                     tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                    exclude_frozen_parameters=True,
-                )
-                # save sampler state dict into checkpoint
-                sampler_state_dict = train_dl.sampler.state_dict()
-                torch.save(
-                    sampler_state_dict,
-                    f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
-                )
-                os.system(
-                    f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+                    client_state={},
+                    exclude_frozen_parameters=True
                 )
+
+                if rank == 0:
+                    convert_zero_checkpoint_to_fp32_state_dict(
+                        params.exp_dir,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
+                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        exclude_frozen_parameters=True,
+                    )
+                    # save sampler state dict into checkpoint
+                    sampler_state_dict = train_dl.sampler.state_dict()
+                    torch.save(
+                        sampler_state_dict,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
+                    )
+                    os.system(
+                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+                    )
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -695,7 +702,10 @@ def run(rank, world_size, args):
     whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
-
+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
+
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
@@ -713,16 +723,22 @@ def run(rank, world_size, args):
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
-    if params.use_lora:
-        lora_config = LoraConfig(
-            r=64,
-            lora_alpha=16,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
-            lora_dropout=0.05,
-            task_type="CAUSAL_LM",
-        )
-        llm = get_peft_model(llm, lora_config)
-        llm.print_trainable_parameters()
+
+    if not params.unfreeze_llm:
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
+        llm.eval()
+    else:
+        if params.use_lora:
+            lora_config = LoraConfig(
+                r=64,
+                lora_alpha=16,
+                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+                lora_dropout=0.05,
+                task_type="CAUSAL_LM",
+            )
+            llm = get_peft_model(llm, lora_config)
+            llm.print_trainable_parameters()
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -733,15 +749,6 @@ def run(rank, world_size, args):
     encoder_projector = EncoderProjector(
         speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
    )
-    for name, param in speech_encoder.named_parameters():
-        param.requires_grad = False
-    speech_encoder.eval()
-
-    if not params.use_lora:
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        llm.eval()
-
     model = SPEECH_LLM(
         speech_encoder,
         llm,
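
For reviewers, a minimal standalone sketch (not part of the patch) of the three training regimes the new `--unfreeze-llm` flag selects in `run()`: frozen LLM (the new default), LoRA adapter tuning, and full fine-tuning. The helper name `configure_llm` is hypothetical; the LoRA hyperparameters are copied from the hunk at `@@ -713,16 +723,22 @@`.

```python
from peft import LoraConfig, get_peft_model


def configure_llm(llm, unfreeze_llm: bool, use_lora: bool):
    """Apply the same freeze/LoRA gating as the patched run() (sketch)."""
    if not unfreeze_llm:
        # Default (--unfreeze-llm=False): the LLM stays frozen; only the
        # remaining trainable modules (e.g. the projector) are updated.
        for param in llm.parameters():
            param.requires_grad = False
        llm.eval()
    elif use_lora:
        # --unfreeze-llm=True with --use-lora=True: train low-rank adapters
        # only, with the hyperparameters used in the diff.
        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "up_proj", "gate_proj", "down_proj",
            ],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        llm = get_peft_model(llm, lora_config)
        llm.print_trainable_parameters()
    # --unfreeze-llm=True with --use-lora=False: full fine-tuning,
    # all LLM parameters remain trainable.
    return llm
```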
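Likewise, a hedged sketch of resuming from what the reworked checkpoint code writes. Because both `model.save_checkpoint` and `convert_zero_checkpoint_to_fp32_state_dict` are called with `exclude_frozen_parameters=True`, the consolidated `.pt` file holds only the trainable parameters and must be loaded non-strictly on top of a freshly built model. `resume_from_checkpoint` is a hypothetical helper; it assumes, as the diff's own usage implies, that the converted file is a plain state dict loadable with `torch.load`, and that `model` and `train_dl` are constructed as in `train.py`.

```python
import torch


def resume_from_checkpoint(model, train_dl, exp_dir: str, epoch: int, batch_idx: int):
    """Reload the consolidated fp32 weights and sampler state (sketch).

    `model` is the assembled SPEECH_LLM module and `train_dl` the training
    dataloader, both as built in train.py (assumption).
    """
    tag = f"epoch-{epoch}-checkpoint-{batch_idx}"

    # Only trainable parameters were saved (exclude_frozen_parameters=True),
    # so load non-strictly on top of the freshly built frozen modules.
    state_dict = torch.load(f"{exp_dir}/{tag}.pt", map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # Restore the sampler so training resumes at the same point in the epoch.
    sampler_state_dict = torch.load(f"{exp_dir}/{tag}-sampler.pt")
    train_dl.sampler.load_state_dict(sampler_state_dict)
    return missing, unexpected
```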