add unfreeze llm option
commit 7db5445d1e
parent dbe85c1f12
@@ -126,6 +126,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to use lora to fine-tune llm.",
     )
 
+    parser.add_argument(
+        "--unfreeze-llm",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze llm during training.",
+    )
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
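
For reference, a minimal, self-contained sketch of how the new flag is meant to parse (not part of the commit; the `str2bool` below is only a stand-in for the helper the recipe imports from icefall.utils):

    import argparse


    def str2bool(v: str) -> bool:
        # Stand-in for icefall.utils.str2bool: accept the usual true/false spellings.
        return v.lower() in ("yes", "true", "t", "y", "1")


    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--unfreeze-llm",
        type=str2bool,
        default=False,
        help="Whether to unfreeze llm during training.",
    )

    args = parser.parse_args(["--unfreeze-llm", "True"])
    assert args.unfreeze_llm is True  # argparse maps the dash to args.unfreeze_llm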
@@ -587,30 +594,30 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-            model.save_checkpoint(
-                save_dir=params.exp_dir,
-                tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                client_state={},
-                exclude_frozen_parameters=True
-            )
-
-            if rank == 0:
-                convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir,
-                    f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
-                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                    exclude_frozen_parameters=True,
-                )
-                # save sampler state dict into checkpoint
-                sampler_state_dict = train_dl.sampler.state_dict()
-                torch.save(
-                    sampler_state_dict,
-                    f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
-                )
-                os.system(
-                    f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
-                )
-
+            if batch_idx != 0:
+                model.save_checkpoint(
+                    save_dir=params.exp_dir,
+                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                    client_state={},
+                    exclude_frozen_parameters=True
+                )
+
+                if rank == 0:
+                    convert_zero_checkpoint_to_fp32_state_dict(
+                        params.exp_dir,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
+                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
+                        exclude_frozen_parameters=True,
+                    )
+                    # save sampler state dict into checkpoint
+                    sampler_state_dict = train_dl.sampler.state_dict()
+                    torch.save(
+                        sampler_state_dict,
+                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
+                    )
+                    os.system(
+                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+                    )
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
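
The checkpoints written above go through DeepSpeed's ZeRO-to-fp32 conversion with exclude_frozen_parameters=True, so the resulting .pt file holds only the trainable weights (the encoder projector and, when LoRA is used, the adapter weights). A hypothetical loading helper, assuming the converted file is a plain state dict saved with torch.save:

    import torch
    from torch import nn


    def load_trainable_only(model: nn.Module, ckpt_path: str) -> None:
        # The file was produced with exclude_frozen_parameters=True, so the frozen
        # parameters are absent and strict=False is required; they simply keep the
        # pretrained values already loaded into `model`.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(f"{len(missing)} frozen keys skipped, {len(unexpected)} unexpected keys")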
@@ -695,7 +702,10 @@ def run(rank, world_size, args):
     whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
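
The three added lines freeze the Whisper encoder immediately after it is loaded. The same pattern, written as a small reusable helper purely for illustration:

    from torch import nn


    def freeze(module: nn.Module) -> nn.Module:
        # Disable gradients so the optimizer (and DeepSpeed) treats these weights
        # as frozen, and switch to eval() so dropout/normalization layers run in
        # inference mode while the rest of the model trains.
        for param in module.parameters():
            param.requires_grad = False
        module.eval()
        return module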
@@ -713,16 +723,22 @@ def run(rank, world_size, args):
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
-    if params.use_lora:
-        lora_config = LoraConfig(
-            r=64,
-            lora_alpha=16,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
-            lora_dropout=0.05,
-            task_type="CAUSAL_LM",
-        )
-        llm = get_peft_model(llm, lora_config)
-        llm.print_trainable_parameters()
+
+    if not params.unfreeze_llm:
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
+        llm.eval()
+    else:
+        if params.use_lora:
+            lora_config = LoraConfig(
+                r=64,
+                lora_alpha=16,
+                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
+                lora_dropout=0.05,
+                task_type="CAUSAL_LM",
+            )
+            llm = get_peft_model(llm, lora_config)
+            llm.print_trainable_parameters()
 
     special_tokens_dict = {
         "additional_special_tokens": [DEFAULT_SPEECH_TOKEN]
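
With this branch, --unfreeze-llm False keeps the whole LLM frozen, while --unfreeze-llm True trains either the LoRA adapters (with --use-lora True) or the full LLM. A quick, hypothetical check of what a given flag combination actually trains:

    from torch import nn


    def count_parameters(module: nn.Module) -> tuple:
        # Returns (trainable, total) parameter counts; useful for confirming that
        # only the intended pieces (projector, LoRA adapters, or full LLM) require
        # gradients after the flags above have been applied.
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        total = sum(p.numel() for p in module.parameters())
        return trainable, total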
@@ -733,15 +749,6 @@ def run(rank, world_size, args):
 
     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
-    for name, param in speech_encoder.named_parameters():
-        param.requires_grad = False
-    speech_encoder.eval()
-
-    if not params.use_lora:
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        llm.eval()
-
     model = SPEECH_LLM(
         speech_encoder,
         llm,