add unfreeze llm option

root 2024-06-13 09:27:07 +00:00 committed by Yuekai Zhang
parent dbe85c1f12
commit 7db5445d1e


@@ -126,6 +126,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to use lora to fine-tune llm.",
     )
+    parser.add_argument(
+        "--unfreeze-llm",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze llm during training.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
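Note on the new flag (not part of the diff): --unfreeze-llm is parsed with icefall's str2bool helper, so the usual textual booleans work on the command line. Below is a minimal sketch of how such a type converter behaves; the standalone str2bool here is an assumption for illustration, while the recipe imports the real one from icefall.utils.

# Sketch only: how a str2bool-style argparse type converter behaves.
# The recipe imports the real helper from icefall.utils; this version is an assumption.
import argparse


def str2bool(v):
    # Map common textual spellings to Python booleans so "--unfreeze-llm True" works.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--unfreeze-llm", type=str2bool, default=False)
print(parser.parse_args(["--unfreeze-llm", "True"]).unfreeze_llm)  # True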
@@ -587,30 +594,30 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-            model.save_checkpoint(
-                save_dir=params.exp_dir,
-                tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                client_state={},
-                exclude_frozen_parameters=True
-            )
-            if rank == 0:
-                convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir,
-                    f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
-                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                    exclude_frozen_parameters=True,
-                )
-                # save sampler state dict into checkpoint
-                sampler_state_dict = train_dl.sampler.state_dict()
-                torch.save(
-                    sampler_state_dict,
-                    f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
-                )
-                os.system(
-                    f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
-                )
+            if batch_idx != 0:
+                model.save_checkpoint(
+                    save_dir=params.exp_dir,
+                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                    client_state={},
+                    exclude_frozen_parameters=True
+                )
+                if rank == 0:
+                    convert_zero_checkpoint_to_fp32_state_dict(
+                        params.exp_dir,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
+                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        exclude_frozen_parameters=True,
+                    )
+                    # save sampler state dict into checkpoint
+                    sampler_state_dict = train_dl.sampler.state_dict()
+                    torch.save(
+                        sampler_state_dict,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
+                    )
+                    os.system(
+                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+                    )
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
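Note on the checkpoint block (not part of the diff): save_checkpoint with exclude_frozen_parameters=True writes only the trainable weights, the sharded DeepSpeed directory is then consolidated into a single fp32 .pt file by convert_zero_checkpoint_to_fp32_state_dict, and the sharded directory is removed. Because frozen parameters are absent from such a file, loading it back needs strict=False. The toy sketch below illustrates that point on a plain nn.Module; it is an analogy, not the recipe's loading code.

# Sketch only: why a checkpoint saved with exclude_frozen_parameters=True
# must be loaded with strict=False. Toy model, not the recipe's SPEECH_LLM.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():   # freeze one part, as the recipe freezes encoder/LLM
    p.requires_grad = False

# Keep only trainable parameters, analogous to exclude_frozen_parameters=True.
trainable_sd = {
    k: v for k, v in model.state_dict().items()
    if model.get_parameter(k).requires_grad
}
torch.save(trainable_sd, "/tmp/trainable-only.pt")

# Frozen keys are missing from the file, so a strict load would raise an error.
restored = torch.load("/tmp/trainable-only.pt")
missing, unexpected = model.load_state_dict(restored, strict=False)
print(sorted(missing))     # ['0.bias', '0.weight'] – the frozen layer
print(sorted(unexpected))  # []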
@@ -695,6 +702,9 @@ def run(rank, world_size, args):
     whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
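Note (not part of the diff): the three added lines freeze the Whisper encoder right after it is loaded, so only the projector (and, depending on the flags below, the LLM or its LoRA adapters) receives gradients. A quick way to confirm what is left trainable is to count parameters; the helper below is a sketch on a stand-in module, not code from the recipe.

# Sketch only: counting trainable parameters after freezing, on a toy stand-in
# for whisper_model.encoder (the real encoder is a full Transformer).
import torch.nn as nn


def count_parameters(module: nn.Module):
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable


encoder = nn.Linear(8, 8)                     # stand-in module
for _, param in encoder.named_parameters():   # same freezing pattern as the diff
    param.requires_grad = False
encoder.eval()                                # also disables dropout updates etc.

print(count_parameters(encoder))  # (72, 0) – nothing left to train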
@ -713,16 +723,22 @@ def run(rank, world_size, args):
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
) )
if params.use_lora:
lora_config = LoraConfig( if not params.unfreeze_llm:
r=64, for name, param in llm.named_parameters():
lora_alpha=16, param.requires_grad = False
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"], llm.eval()
lora_dropout=0.05, else:
task_type="CAUSAL_LM", if params.use_lora:
) lora_config = LoraConfig(
llm = get_peft_model(llm, lora_config) r=64,
llm.print_trainable_parameters() lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()
special_tokens_dict = { special_tokens_dict = {
"additional_special_tokens": [DEFAULT_SPEECH_TOKEN] "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
@@ -733,15 +749,6 @@ def run(rank, world_size, args):
 
     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
-    for name, param in speech_encoder.named_parameters():
-        param.requires_grad = False
-    speech_encoder.eval()
-
-    if not params.use_lora:
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        llm.eval()
-
     model = SPEECH_LLM(
         speech_encoder,
         llm,