add lora for second stage training

This commit is contained in:
root 2024-06-13 07:00:19 +00:00 committed by Yuekai Zhang
parent 3195a55ac7
commit 8226b628f4
4 changed files with 80 additions and 93 deletions

View File

@ -64,6 +64,7 @@ from icefall.utils import (
write_error_stats,
)
from train import DEFAULT_SPEECH_TOKEN
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
@ -138,6 +139,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=False,
help="Whether to use lora to fine-tune llm.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -191,10 +206,10 @@ def get_parser():
)
parser.add_argument(
"--use-flash-attn",
"--use-aishell",
type=str2bool,
default=True,
help="Whether to use flash attention.",
help="Whether to only use aishell1 dataset for training.",
)
add_model_arguments(parser)
@ -495,6 +510,15 @@ def main():
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
@ -560,9 +584,11 @@ def main():
return False
return True
# test_sets_cuts = multi_dataset.test_cuts()
# test_sets_cuts = multi_dataset.aishell_test_cuts()
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
if params.use_aishell:
test_sets_cuts = multi_dataset.aishell_test_cuts()
else:
# test_sets_cuts = multi_dataset.test_cuts()
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
test_sets = test_sets_cuts.keys()
test_dls = [

View File

@ -37,15 +37,8 @@ class SPEECH_LLM(nn.Module):
encoder_projector: nn.Module,
):
super().__init__()
self.encoder = encoder
for name, param in encoder.named_parameters():
param.requires_grad = False
self.encoder.eval()
self.llm = llm
for name, param in llm.named_parameters():
param.requires_grad = False
self.llm.eval()
self.encoder_projector = encoder_projector
def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):

View File

@ -11,4 +11,4 @@ librosa
deepspeed
transformers>=4.37.0
flash-attn
peft

View File

@ -80,6 +80,9 @@ from icefall.utils import (
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
#IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"
@ -109,6 +112,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=1,
help="Downsample rate for the encoder projector.",
)
parser.add_argument(
"--use-flash-attn",
type=str2bool,
default=True,
help="Whether to use flash attention.",
)
parser.add_argument(
"--use-lora",
type=str2bool,
default=False,
help="Whether to use lora to fine-tune llm.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -240,10 +256,10 @@ def get_parser():
)
parser.add_argument(
"--use-flash-attn",
"--use-aishell",
type=str2bool,
default=True,
help="Whether to use flash attention.",
help="Whether to only use aishell1 dataset for training.",
)
parser = deepspeed.add_config_arguments(parser)
@ -294,73 +310,6 @@ def get_params() -> AttributeDict:
return params
# def load_checkpoint_if_available(
# params: AttributeDict,
# model: nn.Module,
# model_avg: nn.Module = None,
# optimizer: Optional[torch.optim.Optimizer] = None,
# scheduler: Optional[LRSchedulerType] = None,
# ) -> Optional[Dict[str, Any]]:
# """Load checkpoint from file.
# If params.start_batch is positive, it will load the checkpoint from
# `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
# params.start_epoch is larger than 1, it will load the checkpoint from
# `params.start_epoch - 1`.
# Apart from loading state dict for `model` and `optimizer` it also updates
# `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
# and `best_valid_loss` in `params`.
# Args:
# params:
# The return value of :func:`get_params`.
# model:
# The training model.
# model_avg:
# The stored model averaged from the start of training.
# optimizer:
# The optimizer that we are using.
# scheduler:
# The scheduler that we are using.
# Returns:
# Return a dict containing previously saved training info.
# """
# if params.start_batch > 0:
# filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
# elif params.start_epoch > 1:
# filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
# else:
# return None
# assert filename.is_file(), f"{filename} does not exist!"
# saved_params = load_checkpoint(
# filename,
# model=model,
# model_avg=model_avg,
# optimizer=optimizer,
# scheduler=scheduler,
# )
# keys = [
# "best_train_epoch",
# "best_valid_epoch",
# "batch_idx_train",
# "best_train_loss",
# "best_valid_loss",
# ]
# for k in keys:
# params[k] = saved_params[k]
# if params.start_batch > 0:
# if "cur_epoch" in saved_params:
# params["start_epoch"] = saved_params["cur_epoch"]
# return saved_params
def compute_loss(
params: AttributeDict,
tokenizer: AutoTokenizer,
@ -764,6 +713,16 @@ def run(rank, world_size, args):
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
if params.use_lora:
lora_config = LoraConfig(
r=64,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
@ -774,6 +733,15 @@ def run(rank, world_size, args):
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
speech_encoder.eval()
if not params.use_lora:
for name, param in llm.named_parameters():
param.requires_grad = False
llm.eval()
model = SPEECH_LLM(
speech_encoder,
llm,
@ -782,7 +750,8 @@ def run(rank, world_size, args):
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
assert len(unexpected_keys) == 0, unexpected_keys
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -799,11 +768,6 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
model.to(device)
# assert params.start_epoch > 0, params.start_epoch
# checkpoints = load_checkpoint_if_available(
# params=params, model=model, model_avg=model_avg
# )
assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
@ -828,10 +792,12 @@ def run(rank, world_size, args):
# )
return False
return True
if params.use_aishell:
train_cuts = multi_dataset.aishell_train_cuts()
else:
train_cuts = multi_dataset.train_cuts()
train_cuts = multi_dataset.train_cuts()
# train_cuts = multi_dataset.aishell_train_cuts()
# train_cuts = multi_dataset.aishell2_train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
# if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
@ -846,8 +812,10 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
# valid_cuts = multi_dataset.dev_cuts()
valid_cuts = multi_dataset.aishell_dev_cuts()
if params.use_aishell:
valid_cuts = multi_dataset.aishell_dev_cuts()
else:
valid_cuts = multi_dataset.dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts)
if args.tensorboard and rank == 0: