diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index 88b831f2f..e2eb77c6b 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -64,6 +64,7 @@ from icefall.utils import (
     write_error_stats,
 )
 from train import DEFAULT_SPEECH_TOKEN
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 
 
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -138,6 +139,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Downsample rate for the encoder projector.",
     )
 
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=False,
+        help="Whether to use lora to fine-tune llm.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
@@ -191,10 +206,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--use-flash-attn",
+        "--use-aishell",
         type=str2bool,
         default=True,
-        help="Whether to use flash attention.",
+        help="Whether to only use aishell1 dataset for training.",
     )
 
     add_model_arguments(parser)
@@ -495,6 +510,15 @@ def main():
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()
 
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -560,9 +584,11 @@ def main():
             return False
         return True
 
-    # test_sets_cuts = multi_dataset.test_cuts()
-    # test_sets_cuts = multi_dataset.aishell_test_cuts()
-    test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
+    if params.use_aishell:
+        test_sets_cuts = multi_dataset.aishell_test_cuts()
+    else:
+        # test_sets_cuts = multi_dataset.test_cuts()
+        test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
 
     test_sets = test_sets_cuts.keys()
     test_dls = [
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
index df86f87af..440724db2 100644
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -37,15 +37,8 @@ class SPEECH_LLM(nn.Module):
         encoder_projector: nn.Module,
     ):
         super().__init__()
-
         self.encoder = encoder
-        for name, param in encoder.named_parameters():
-            param.requires_grad = False
-        self.encoder.eval()
         self.llm = llm
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        self.llm.eval()
         self.encoder_projector = encoder_projector
 
     def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt
index bde73601f..c5a90cb08 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/requirements.txt
@@ -11,4 +11,4 @@ librosa
 deepspeed
 transformers>=4.37.0
 flash-attn
-
+peft
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 21b615930..99b8dae0b 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -80,6 +80,9 @@ from icefall.utils import (
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import transformers
 from transformers.trainer_pt_utils import LabelSmoother
+
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
 #IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 DEFAULT_SPEECH_TOKEN = "<speech>"
@@ -109,6 +112,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Downsample rate for the encoder projector.",
     )
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=False,
+        help="Whether to use lora to fine-tune llm.",
+    )
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -240,10 +256,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--use-flash-attn",
+        "--use-aishell",
         type=str2bool,
         default=True,
-        help="Whether to use flash attention.",
+        help="Whether to only use aishell1 dataset for training.",
     )
 
     parser = deepspeed.add_config_arguments(parser)
@@ -294,73 +310,6 @@ def get_params() -> AttributeDict:
 
     return params
 
-
-# def load_checkpoint_if_available(
-#     params: AttributeDict,
-#     model: nn.Module,
-#     model_avg: nn.Module = None,
-#     optimizer: Optional[torch.optim.Optimizer] = None,
-#     scheduler: Optional[LRSchedulerType] = None,
-# ) -> Optional[Dict[str, Any]]:
-#     """Load checkpoint from file.
-
-#     If params.start_batch is positive, it will load the checkpoint from
-#     `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-#     params.start_epoch is larger than 1, it will load the checkpoint from
-#     `params.start_epoch - 1`.
-
-#     Apart from loading state dict for `model` and `optimizer` it also updates
-#     `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
-#     and `best_valid_loss` in `params`.
-
-#     Args:
-#       params:
-#         The return value of :func:`get_params`.
-#       model:
-#         The training model.
-#       model_avg:
-#         The stored model averaged from the start of training.
-#       optimizer:
-#         The optimizer that we are using.
-#       scheduler:
-#         The scheduler that we are using.
-#     Returns:
-#       Return a dict containing previously saved training info.
-#     """
-#     if params.start_batch > 0:
-#         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-#     elif params.start_epoch > 1:
-#         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
-#     else:
-#         return None
-
-#     assert filename.is_file(), f"{filename} does not exist!"
-
-#     saved_params = load_checkpoint(
-#         filename,
-#         model=model,
-#         model_avg=model_avg,
-#         optimizer=optimizer,
-#         scheduler=scheduler,
-#     )
-
-#     keys = [
-#         "best_train_epoch",
-#         "best_valid_epoch",
-#         "batch_idx_train",
-#         "best_train_loss",
-#         "best_valid_loss",
-#     ]
-#     for k in keys:
-#         params[k] = saved_params[k]
-
-#     if params.start_batch > 0:
-#         if "cur_epoch" in saved_params:
-#             params["start_epoch"] = saved_params["cur_epoch"]
-
-#     return saved_params
-
-
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -764,6 +713,16 @@ def run(rank, world_size, args):
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+            lora_dropout=0.05,
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()
 
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
     }
@@ -774,6 +733,15 @@ def run(rank, world_size, args):
 
     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
+
+    if not params.use_lora:
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
+        llm.eval()
+
     model = SPEECH_LLM(
         speech_encoder,
         llm,
@@ -782,7 +750,8 @@ def run(rank, world_size, args):
 
     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-        model.load_state_dict(checkpoint, strict=False)
+        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        assert len(unexpected_keys) == 0, unexpected_keys
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -799,11 +768,6 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)
 
-    # assert params.start_epoch > 0, params.start_epoch
-    # checkpoints = load_checkpoint_if_available(
-    #     params=params, model=model, model_avg=model_avg
-    # )
-
     assert params.deepspeed and world_size > 1
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
@@ -828,10 +792,12 @@ def run(rank, world_size, args):
             # )
             return False
         return True
+
+    if params.use_aishell:
+        train_cuts = multi_dataset.aishell_train_cuts()
+    else:
+        train_cuts = multi_dataset.train_cuts()
 
-    train_cuts = multi_dataset.train_cuts()
-    # train_cuts = multi_dataset.aishell_train_cuts()
-    # train_cuts = multi_dataset.aishell2_train_cuts()
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
@@ -846,8 +812,10 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
    )
 
-    # valid_cuts = multi_dataset.dev_cuts()
-    valid_cuts = multi_dataset.aishell_dev_cuts()
+    if params.use_aishell:
+        valid_cuts = multi_dataset.aishell_dev_cuts()
+    else:
+        valid_cuts = multi_dataset.dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)
 
     if args.tensorboard and rank == 0: