Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)
add unfreeze llm option

Commit: 7db5445d1e
Parent: dbe85c1f12
```diff
@@ -126,6 +126,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Whether to use lora to fine-tune llm.",
     )
 
+    parser.add_argument(
+        "--unfreeze-llm",
+        type=str2bool,
+        default=False,
+        help="Whether to unfreeze llm during training.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
```
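The flag goes through `str2bool`, so `--unfreeze-llm True`/`False` on the command line becomes a real Python boolean rather than a truthy string. A minimal sketch of that parsing behavior; the `str2bool` below is a stand-in that only approximates icefall's own helper:

```python
import argparse


def str2bool(v: str) -> bool:
    # Stand-in for icefall's str2bool utility: maps "true"/"false"-style
    # strings to bool and rejects anything else.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Not a boolean value: {v}")


parser = argparse.ArgumentParser()
parser.add_argument("--unfreeze-llm", type=str2bool, default=False)

args = parser.parse_args(["--unfreeze-llm", "true"])
assert args.unfreeze_llm is True
```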
```diff
@@ -587,7 +594,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
 
         if batch_idx != 0:
             model.save_checkpoint(
                 save_dir=params.exp_dir,
                 tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
```
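`model.save_checkpoint(save_dir=..., tag=...)` writes one snapshot per tag, so each epoch/batch pair gets its own directory entry. A rough sketch of the same guard-and-tag pattern using plain `torch.save`; the helper name and file layout here are illustrative, not the engine's actual API:

```python
import os

import torch
import torch.nn as nn


def save_tagged_checkpoint(model: nn.Module, save_dir: str, tag: str) -> None:
    # Illustrative stand-in for the engine's save_checkpoint: one file per tag.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"{tag}.pt"))


model = nn.Linear(4, 4)
cur_epoch, batch_idx = 1, 200
if batch_idx != 0:  # same guard as the diff: skip the very first batch
    save_tagged_checkpoint(model, "exp", tag=f"epoch-{cur_epoch}-checkpoint-{batch_idx}")
```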
```diff
@@ -695,6 +702,9 @@ def run(rank, world_size, args):
     whisper_model = whisper.load_model(params.speech_encoder_path_or_name, "cpu")
     speech_encoder = whisper_model.encoder
     speech_encoder_dim = whisper_model.dims.n_audio_state
+    for name, param in speech_encoder.named_parameters():
+        param.requires_grad = False
+    speech_encoder.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
     if params.use_flash_attn:
```
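Freezing right after loading (rather than further down, as the old code did) fixes the Whisper encoder's behavior from the start: `requires_grad = False` stops gradient updates, and `eval()` pins dropout and any normalization statistics. A quick way to verify a module is fully frozen, with a toy module standing in for `whisper_model.encoder`:

```python
import torch.nn as nn

# Toy stand-in for whisper_model.encoder.
encoder = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 256))

# Same freeze pattern as the diff.
for name, param in encoder.named_parameters():
    param.requires_grad = False
encoder.eval()

num_trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
assert num_trainable == 0  # nothing left to train in the encoder
```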
```diff
@@ -713,6 +723,12 @@ def run(rank, world_size, args):
         attn_implementation=attn_implementation,
         torch_dtype=torch_dtype,
     )
 
+    if not params.unfreeze_llm:
+        for name, param in llm.named_parameters():
+            param.requires_grad = False
+        llm.eval()
+    else:
+        if params.use_lora:
             lora_config = LoraConfig(
                 r=64,
```
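When the LLM is unfrozen with LoRA, only the injected low-rank adapters receive gradients, which keeps the trainable parameter count small. The diff truncates the `LoraConfig` after `r=64`; a sketch of a typical PEFT configuration, where every field other than `r=64` is illustrative and not taken from this commit:

```python
from peft import LoraConfig, get_peft_model  # get_peft_model wraps the base llm

lora_config = LoraConfig(
    r=64,                                 # rank, from the diff
    lora_alpha=16,                        # illustrative scaling factor
    lora_dropout=0.05,                    # illustrative
    target_modules=["q_proj", "v_proj"],  # illustrative attention projections
    task_type="CAUSAL_LM",
)
# llm = get_peft_model(llm, lora_config)  # adapters become the trainable subset
```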
```diff
@@ -733,15 +749,6 @@ def run(rank, world_size, args):
 
     encoder_projector = EncoderProjector(speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate)
 
-    for name, param in speech_encoder.named_parameters():
-        param.requires_grad = False
-    speech_encoder.eval()
-
-    if not params.use_lora:
-        for name, param in llm.named_parameters():
-            param.requires_grad = False
-        llm.eval()
-
     model = SPEECH_LLM(
         speech_encoder,
         llm,
```
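Taken together, the commit moves all freezing logic to one place, driven by two flags: the speech encoder is always frozen, and the LLM stays frozen unless `--unfreeze-llm` is given (optionally via LoRA). A self-contained toy mirroring that decision, with `SimpleNamespace` standing in for the recipe's `params` and tiny linear layers standing in for the real encoder and LLM:

```python
from types import SimpleNamespace

import torch.nn as nn

# Stand-in for the recipe's params; flip the flags to explore the behavior.
params = SimpleNamespace(unfreeze_llm=False, use_lora=False)

speech_encoder = nn.Linear(80, 256)  # toy stand-in for the Whisper encoder
llm = nn.Linear(256, 256)            # toy stand-in for the LLM

# The speech encoder is frozen unconditionally, right after loading.
for p in speech_encoder.parameters():
    p.requires_grad = False
speech_encoder.eval()

# The LLM is frozen unless --unfreeze-llm was given.
if not params.unfreeze_llm:
    for p in llm.parameters():
        p.requires_grad = False
    llm.eval()

trainable = [
    n
    for m in (speech_encoder, llm)
    for n, p in m.named_parameters()
    if p.requires_grad
]
print(trainable)  # [] with the defaults above; llm params appear once unfrozen
```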