mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add unfreeze llm option
This commit is contained in:
parent
dbe85c1f12
commit
7db5445d1e
@ -126,6 +126,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="Whether to use lora to fine-tune llm.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--unfreeze-llm",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to unfreeze llm during training.",
|
||||
)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -587,30 +594,30 @@ def train_one_epoch(
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
if batch_idx != 0:
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
||||
)
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
sampler_state_dict = train_dl.sampler.state_dict()
|
||||
torch.save(
|
||||
sampler_state_dict,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
|
||||
)
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -695,7 +702,10 @@ def run(rank, world_size, args):
|
||||
whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
|
||||
speech_encoder = whisper_model.encoder
|
||||
speech_encoder_dim = whisper_model.dims.n_audio_state
|
||||
|
||||
for name, param in speech_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
speech_encoder.eval()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
@ -713,16 +723,22 @@ def run(rank, world_size, args):
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
if params.use_lora:
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
|
||||
lora_dropout=0.05,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
llm = get_peft_model(llm, lora_config)
|
||||
llm.print_trainable_parameters()
|
||||
|
||||
if not params.unfreeze_llm:
|
||||
for name, param in llm.named_parameters():
|
||||
param.requires_grad = False
|
||||
llm.eval()
|
||||
else:
|
||||
if params.use_lora:
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
|
||||
lora_dropout=0.05,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
llm = get_peft_model(llm, lora_config)
|
||||
llm.print_trainable_parameters()
|
||||
|
||||
special_tokens_dict = {
|
||||
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
|
||||
@ -733,15 +749,6 @@ def run(rank, world_size, args):
|
||||
|
||||
encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
|
||||
|
||||
for name, param in speech_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
speech_encoder.eval()
|
||||
|
||||
if not params.use_lora:
|
||||
for name, param in llm.named_parameters():
|
||||
param.requires_grad = False
|
||||
llm.eval()
|
||||
|
||||
model = SPEECH_LLM(
|
||||
speech_encoder,
|
||||
llm,
|
||||
|
Loading…
x
Reference in New Issue
Block a user