Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

add lora for second stage training

commit 8226b628f4
parent 3195a55ac7
@@ -64,6 +64,7 @@ from icefall.utils import (
     write_error_stats,
 )
 from train import DEFAULT_SPEECH_TOKEN
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -138,6 +139,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Downsample rate for the encoder projector.",
     )

+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=False,
+        help="Whether to use LoRA to fine-tune the LLM.",
+    )
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
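Note (editorial, not part of the commit): both new options are declared with icefall's str2bool, so they take an explicit true/false value on the command line rather than acting as store_true switches. A minimal sketch of the expected parsing behaviour; the str2bool below is a simplified stand-in for the helper imported from icefall.utils:

import argparse


def str2bool(v: str) -> bool:
    # Simplified stand-in for icefall.utils.str2bool.
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use-lora", type=str2bool, default=False)
parser.add_argument("--use-flash-attn", type=str2bool, default=True)
args = parser.parse_args(["--use-lora", "True"])
assert args.use_lora is True and args.use_flash_attn is True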
@@ -191,10 +206,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--use-flash-attn",
+        "--use-aishell",
         type=str2bool,
         default=True,
-        help="Whether to use flash attention.",
+        help="Whether to use only the aishell1 dataset for training.",
     )

     add_model_arguments(parser)
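Note (editorial assumption): --use-flash-attn is consumed elsewhere in the script, where the LLM is constructed with attn_implementation and torch_dtype (see the context lines of the next hunk). That mapping is not shown in this excerpt; a plausible sketch, using the standard Hugging Face backend names:

import torch


def select_attn_backend(use_flash_attn: bool):
    # Assumed mapping, not taken from the diff: "flash_attention_2" and
    # "eager" are standard attn_implementation values; flash-attn kernels
    # require half precision.
    if use_flash_attn:
        return "flash_attention_2", torch.float16
    return "eager", torch.float32


attn_implementation, torch_dtype = select_attn_backend(True)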
@@ -495,6 +510,15 @@ def main():
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()

     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
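Note (editorial, not part of the commit): with --use-lora True the decode script re-wraps the LLM with the same LoraConfig used in training, so the checkpoint's LoRA tensors line up with the model's parameter names. If a plain, non-PEFT model is wanted for deployment, the adapters can instead be folded into the base weights; a minimal sketch, assuming llm is the PeftModel returned by get_peft_model above and with an illustrative output path:

def export_merged_llm(llm, out_dir: str) -> None:
    # merge_and_unload() folds each LoRA update (B @ A scaled by
    # lora_alpha / r) into the corresponding base Linear weight and
    # removes the PEFT wrappers, leaving an ordinary causal LM.
    merged = llm.merge_and_unload()
    merged.save_pretrained(out_dir)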
@@ -560,9 +584,11 @@ def main():
             return False
         return True

-    # test_sets_cuts = multi_dataset.test_cuts()
-    # test_sets_cuts = multi_dataset.aishell_test_cuts()
-    test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
+    if params.use_aishell:
+        test_sets_cuts = multi_dataset.aishell_test_cuts()
+    else:
+        # test_sets_cuts = multi_dataset.test_cuts()
+        test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()

     test_sets = test_sets_cuts.keys()
     test_dls = [
@@ -37,15 +37,8 @@ class SPEECH_LLM(nn.Module):
         encoder_projector: nn.Module,
     ):
         super().__init__()
-
         self.encoder = encoder
-        for name, param in encoder.named_parameters():
-            param.requires_grad = False
-        self.encoder.eval()
         self.llm = llm
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        self.llm.eval()
         self.encoder_projector = encoder_projector

     def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
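Note (editorial, not part of the commit): the freezing that used to happen unconditionally in SPEECH_LLM.__init__ is removed here and moved into the training script (see the run() hunks below), because freezing every LLM parameter in the constructor would also freeze the LoRA matrices injected by get_peft_model. A small helper for auditing what is still trainable once the model is assembled, similar in spirit to peft's print_trainable_parameters():

import torch.nn as nn


def count_trainable(model: nn.Module) -> str:
    # Count parameters that will actually receive gradients.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return f"trainable: {trainable} / {total} ({100.0 * trainable / total:.2f}%)"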
@@ -11,4 +11,4 @@ librosa
 deepspeed
 transformers>=4.37.0
 flash-attn
-
+peft
@@ -80,6 +80,9 @@ from icefall.utils import (
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import transformers
 from transformers.trainer_pt_utils import LabelSmoother
+
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
 #IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 DEFAULT_SPEECH_TOKEN = "<speech>"

@@ -109,6 +112,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Downsample rate for the encoder projector.",
     )
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=False,
+        help="Whether to use LoRA to fine-tune the LLM.",
+    )

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -240,10 +256,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--use-flash-attn",
+        "--use-aishell",
         type=str2bool,
         default=True,
-        help="Whether to use flash attention.",
+        help="Whether to use only the aishell1 dataset for training.",
     )

     parser = deepspeed.add_config_arguments(parser)
@@ -294,73 +310,6 @@ def get_params() -> AttributeDict:

     return params

-
-# def load_checkpoint_if_available(
-#     params: AttributeDict,
-#     model: nn.Module,
-#     model_avg: nn.Module = None,
-#     optimizer: Optional[torch.optim.Optimizer] = None,
-#     scheduler: Optional[LRSchedulerType] = None,
-# ) -> Optional[Dict[str, Any]]:
-#     """Load checkpoint from file.
-
-#     If params.start_batch is positive, it will load the checkpoint from
-#     `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-#     params.start_epoch is larger than 1, it will load the checkpoint from
-#     `params.start_epoch - 1`.
-
-#     Apart from loading state dict for `model` and `optimizer` it also updates
-#     `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
-#     and `best_valid_loss` in `params`.
-
-#     Args:
-#       params:
-#         The return value of :func:`get_params`.
-#       model:
-#         The training model.
-#       model_avg:
-#         The stored model averaged from the start of training.
-#       optimizer:
-#         The optimizer that we are using.
-#       scheduler:
-#         The scheduler that we are using.
-#     Returns:
-#       Return a dict containing previously saved training info.
-#     """
-#     if params.start_batch > 0:
-#         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-#     elif params.start_epoch > 1:
-#         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
-#     else:
-#         return None
-
-#     assert filename.is_file(), f"{filename} does not exist!"
-
-#     saved_params = load_checkpoint(
-#         filename,
-#         model=model,
-#         model_avg=model_avg,
-#         optimizer=optimizer,
-#         scheduler=scheduler,
-#     )
-
-#     keys = [
-#         "best_train_epoch",
-#         "best_valid_epoch",
-#         "batch_idx_train",
-#         "best_train_loss",
-#         "best_valid_loss",
-#     ]
-#     for k in keys:
-#         params[k] = saved_params[k]
-
-#     if params.start_batch > 0:
-#         if "cur_epoch" in saved_params:
-#             params["start_epoch"] = saved_params["cur_epoch"]
-
-#     return saved_params
-
-
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -764,6 +713,16 @@ def run(rank, world_size, args):
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+            lora_dropout=0.05,
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()

     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
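Note (editorial, not part of the commit): for each targeted nn.Linear with input size d_in and output size d_out, LoRA adds two small matrices A (r x d_in) and B (d_out x r), i.e. r * (d_in + d_out) extra trainable parameters, and the update B @ A is scaled by lora_alpha / r (16 / 64 = 0.25 with the values above). A back-of-envelope sketch; the hidden size below is a placeholder, not taken from this recipe:

# Rough LoRA size for one square projection with r=64.
r, lora_alpha = 64, 16
d_in = d_out = 3584                  # hypothetical hidden size
extra_params = r * (d_in + d_out)    # 458_752 per targeted Linear
scaling = lora_alpha / r             # 0.25
print(extra_params, scaling)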
@@ -774,6 +733,15 @@ def run(rank, world_size, args):

     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)

+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
+
+    if not params.use_lora:
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
+        llm.eval()
+
     model = SPEECH_LLM(
         speech_encoder,
         llm,
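Note (editorial, not part of the commit): after this hunk, only the encoder projector and (with --use-lora True) the LoRA matrices keep requires_grad=True. The usual way to make sure frozen weights never reach the optimizer is to hand over a filtered iterator; deepspeed.initialize accepts such an iterator via its model_parameters argument, though the exact wiring is outside this excerpt. A plain PyTorch sketch:

import torch


def trainable_parameters(model: torch.nn.Module):
    # Yield only the parameters left unfrozen above.
    return (p for p in model.parameters() if p.requires_grad)


# e.g. optimizer = torch.optim.AdamW(trainable_parameters(model), lr=1e-4)  # lr is illustrative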
@@ -782,7 +750,8 @@ def run(rank, world_size, args):

     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-        model.load_state_dict(checkpoint, strict=False)
+        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        assert len(unexpected_keys) == 0, unexpected_keys

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
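Note (editorial, not part of the commit): strict=False is needed because the stage-1 checkpoint was saved before the LLM was wrapped with LoRA, so the freshly injected lora_A / lora_B tensors are legitimately missing from it, while any unexpected key would signal a genuinely mismatched state dict, which the new assert now catches. A tiny self-contained illustration of the missing/unexpected distinction using toy modules:

import torch.nn as nn

old = nn.Linear(4, 4)                                    # stands in for the stage-1 model
new = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))    # the model later grew new parts

state = {f"0.{k}": v for k, v in old.state_dict().items()}
result = new.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']: new parameters absent from the checkpoint
print(result.unexpected_keys)  # []: nothing in the checkpoint was unknown to the model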
@@ -799,11 +768,6 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)

-    # assert params.start_epoch > 0, params.start_epoch
-    # checkpoints = load_checkpoint_if_available(
-    #     params=params, model=model, model_avg=model_avg
-    # )
-
     assert params.deepspeed and world_size > 1
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
@@ -828,10 +792,12 @@ def run(rank, world_size, args):
             # )
             return False
         return True

-    train_cuts = multi_dataset.train_cuts()
-    # train_cuts = multi_dataset.aishell_train_cuts()
-    # train_cuts = multi_dataset.aishell2_train_cuts()
+    if params.use_aishell:
+        train_cuts = multi_dataset.aishell_train_cuts()
+    else:
+        train_cuts = multi_dataset.train_cuts()
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
@@ -846,8 +812,10 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    # valid_cuts = multi_dataset.dev_cuts()
-    valid_cuts = multi_dataset.aishell_dev_cuts()
+    if params.use_aishell:
+        valid_cuts = multi_dataset.aishell_dev_cuts()
+    else:
+        valid_cuts = multi_dataset.dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)

     if args.tensorboard and rank == 0:
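Note (editorial, not part of the commit): after second-stage training the LoRA weights live inside the PEFT-wrapped self.llm of SPEECH_LLM. PEFT can store the adapter separately from the much larger base LLM; a sketch, assuming model.llm is the PeftModel built in run(), base_llm is a freshly loaded base model, and that any DeepSpeed ZeRO gathering of partitioned weights is handled before saving:

from peft import PeftModel

# Write only adapter_model + adapter_config for the LoRA weights.
model.llm.save_pretrained("exp/lora_adapter")            # path is illustrative

# Later, e.g. for decoding: rebuild the base LLM, then re-attach the adapter.
llm = PeftModel.from_pretrained(base_llm, "exp/lora_adapter")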