https://github.com/k2-fsa/icefall.git
add lora for second stage training

commit 8226b628f4 (parent 3195a55ac7)
@@ -64,6 +64,7 @@ from icefall.utils import (
     write_error_stats,
 )
 from train import DEFAULT_SPEECH_TOKEN
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
@@ -138,6 +139,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Downsample rate for the encoder projector.",
     )

+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=False,
+        help="Whether to use lora to fine-tune llm.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -191,10 +206,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--use-flash-attn",
+        "--use-aishell",
         type=str2bool,
         default=True,
-        help="Whether to use flash attention.",
+        help="Whether to only use aishell1 dataset for training.",
     )

     add_model_arguments(parser)
@@ -495,6 +510,15 @@ def main():
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()

     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
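For reference, the peft wrapping pattern the hunk above relies on, as a standalone sketch. The model id and dtype handling below are illustrative and not taken from this commit; only LoraConfig, get_peft_model, and print_trainable_parameters mirror the diff. The decode-side config omits lora_dropout, which only matters during training anyway (dropout is inactive in eval mode).

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Any HF causal LM works here; the id is a placeholder, not the recipe's model.
llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct", torch_dtype="auto")

lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()  # reports trainable (LoRA adapter) vs. total parameters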
@@ -560,9 +584,11 @@ def main():
             return False
         return True

-    # test_sets_cuts = multi_dataset.test_cuts()
-    # test_sets_cuts = multi_dataset.aishell_test_cuts()
-    test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
+    if params.use_aishell:
+        test_sets_cuts = multi_dataset.aishell_test_cuts()
+    else:
+        # test_sets_cuts = multi_dataset.test_cuts()
+        test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()

     test_sets = test_sets_cuts.keys()
     test_dls = [
@@ -37,15 +37,8 @@ class SPEECH_LLM(nn.Module):
         encoder_projector: nn.Module,
     ):
         super().__init__()

         self.encoder = encoder
-        for name, param in encoder.named_parameters():
-            param.requires_grad = False
-        self.encoder.eval()
         self.llm = llm
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        self.llm.eval()
         self.encoder_projector = encoder_projector

     def _merge_input_ids_with_speech_features(self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None):
@@ -11,4 +11,4 @@ librosa
 deepspeed
 transformers>=4.37.0
 flash-attn
+peft
@@ -80,6 +80,9 @@ from icefall.utils import (
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import transformers
 from transformers.trainer_pt_utils import LabelSmoother

+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
 #IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 DEFAULT_SPEECH_TOKEN = "<speech>"

@@ -109,6 +112,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Downsample rate for the encoder projector.",
     )
+    parser.add_argument(
+        "--use-flash-attn",
+        type=str2bool,
+        default=True,
+        help="Whether to use flash attention.",
+    )
+
+    parser.add_argument(
+        "--use-lora",
+        type=str2bool,
+        default=False,
+        help="Whether to use lora to fine-tune llm.",
+    )


 def get_parser():
     parser = argparse.ArgumentParser(
@@ -240,10 +256,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--use-flash-attn",
+        "--use-aishell",
         type=str2bool,
         default=True,
-        help="Whether to use flash attention.",
+        help="Whether to only use aishell1 dataset for training.",
     )

     parser = deepspeed.add_config_arguments(parser)
@@ -294,73 +310,6 @@ def get_params() -> AttributeDict:

     return params


-# def load_checkpoint_if_available(
-#     params: AttributeDict,
-#     model: nn.Module,
-#     model_avg: nn.Module = None,
-#     optimizer: Optional[torch.optim.Optimizer] = None,
-#     scheduler: Optional[LRSchedulerType] = None,
-# ) -> Optional[Dict[str, Any]]:
-#     """Load checkpoint from file.
-
-#     If params.start_batch is positive, it will load the checkpoint from
-#     `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-#     params.start_epoch is larger than 1, it will load the checkpoint from
-#     `params.start_epoch - 1`.
-
-#     Apart from loading state dict for `model` and `optimizer` it also updates
-#     `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
-#     and `best_valid_loss` in `params`.
-
-#     Args:
-#       params:
-#         The return value of :func:`get_params`.
-#       model:
-#         The training model.
-#       model_avg:
-#         The stored model averaged from the start of training.
-#       optimizer:
-#         The optimizer that we are using.
-#       scheduler:
-#         The scheduler that we are using.
-#     Returns:
-#       Return a dict containing previously saved training info.
-#     """
-#     if params.start_batch > 0:
-#         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-#     elif params.start_epoch > 1:
-#         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
-#     else:
-#         return None
-
-#     assert filename.is_file(), f"{filename} does not exist!"
-
-#     saved_params = load_checkpoint(
-#         filename,
-#         model=model,
-#         model_avg=model_avg,
-#         optimizer=optimizer,
-#         scheduler=scheduler,
-#     )
-
-#     keys = [
-#         "best_train_epoch",
-#         "best_valid_epoch",
-#         "batch_idx_train",
-#         "best_train_loss",
-#         "best_valid_loss",
-#     ]
-#     for k in keys:
-#         params[k] = saved_params[k]
-
-#     if params.start_batch > 0:
-#         if "cur_epoch" in saved_params:
-#             params["start_epoch"] = saved_params["cur_epoch"]
-
-#     return saved_params
-
-
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -764,6 +713,16 @@ def run(rank, world_size, args):
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
+    if params.use_lora:
+        lora_config = LoraConfig(
+            r=64,
+            lora_alpha=16,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+            lora_dropout=0.05,
+            task_type="CAUSAL_LM",
+        )
+        llm = get_peft_model(llm, lora_config)
+        llm.print_trainable_parameters()

     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
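print_trainable_parameters() logs how many parameters LoRA leaves trainable relative to the full model. A tiny equivalent, shown only as a sketch (it is not part of the recipe, which later logs the total count via sum(p.numel() for p in model.parameters())):

def count_parameters(model):
    # Returns (trainable, total) parameter counts; with LoRA enabled, trainable
    # covers only the injected lora_A / lora_B tensors plus any unfrozen modules.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total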
@@ -774,6 +733,15 @@ def run(rank, world_size, args):

     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)

+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
+
+    if not params.use_lora:
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
+        llm.eval()
+
     model = SPEECH_LLM(
         speech_encoder,
         llm,
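The hunk above keeps the speech encoder frozen unconditionally and freezes the LLM only when LoRA is off, because get_peft_model already freezes the base weights and marks only the injected adapter tensors as trainable. A self-contained toy check of that behavior (TinyAttnBlock below is illustrative, not the recipe's SPEECH_LLM):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TinyAttnBlock(nn.Module):
    # Toy stand-in for one transformer block; the projection names mirror
    # part of the target_modules list used in the diff.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))

cfg = LoraConfig(r=4, lora_alpha=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])
wrapped = get_peft_model(TinyAttnBlock(), cfg)

# Only the injected adapter tensors require grad; the base projections are frozen,
# which is why the explicit freeze loop is skipped when --use-lora is given.
for name, p in wrapped.named_parameters():
    assert p.requires_grad == ("lora_" in name), name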
@@ -782,7 +750,8 @@ def run(rank, world_size, args):

     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-        model.load_state_dict(checkpoint, strict=False)
+        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
+        assert len(unexpected_keys) == 0, unexpected_keys

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
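This change surfaces the two key lists that strict=False loading returns: parameters the model has but the checkpoint lacks (for example adapter weights newly introduced by LoRA) land in missing_keys and are harmless, while checkpoint keys the model does not recognize would indicate a naming mismatch and now fail fast. A self-contained toy illustration of that contract (the module and key names below are illustrative):

import torch
import torch.nn as nn

# Toy model with one "pretrained" projection and one newly added adapter-like layer.
model = nn.ModuleDict({"proj": nn.Linear(8, 8), "lora_A": nn.Linear(8, 2, bias=False)})

# A stage-one checkpoint knows nothing about the adapter.
checkpoint = {"proj.weight": torch.zeros(8, 8), "proj.bias": torch.zeros(8)}

missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
assert len(unexpected_keys) == 0, unexpected_keys
print(missing_keys)  # ['lora_A.weight'] -- freshly initialized weights, not an error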
@@ -799,11 +768,6 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)

-    # assert params.start_epoch > 0, params.start_epoch
-    # checkpoints = load_checkpoint_if_available(
-    #     params=params, model=model, model_avg=model_avg
-    # )
-
     assert params.deepspeed and world_size > 1
     logging.info("Using DeepSpeed")
     model, optimizer, _, scheduler = deepspeed.initialize(
@@ -828,10 +792,12 @@ def run(rank, world_size, args):
        # )
            return False
        return True

-    train_cuts = multi_dataset.train_cuts()
-    # train_cuts = multi_dataset.aishell_train_cuts()
-    # train_cuts = multi_dataset.aishell2_train_cuts()
+    if params.use_aishell:
+        train_cuts = multi_dataset.aishell_train_cuts()
+    else:
+        train_cuts = multi_dataset.train_cuts()
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     # if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
@@ -846,8 +812,10 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    # valid_cuts = multi_dataset.dev_cuts()
-    valid_cuts = multi_dataset.aishell_dev_cuts()
+    if params.use_aishell:
+        valid_cuts = multi_dataset.aishell_dev_cuts()
+    else:
+        valid_cuts = multi_dataset.dev_cuts()
     valid_dl = data_module.valid_dataloaders(valid_cuts)

     if args.tensorboard and rank == 0: