support freezing modules

This commit is contained in:
marcoyang 2024-03-28 18:16:33 +08:00
parent 360f208037
commit cfbc829df3

View File

@ -88,15 +88,6 @@ from icefall.utils import (
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -226,6 +217,13 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--freeze-modules",
type=str,
default=None,
help="Which modules to freeze during finetune"
)
parser = deepspeed.add_config_arguments(parser)
return parser
@ -583,6 +581,9 @@ def train_one_epoch(
be set to 0.
"""
model.train()
for name, module in model.named_modules():
if name.startswith(params.freeze_modules):
module.eval()
tot_loss = MetricsTracker()
@ -630,7 +631,6 @@ def train_one_epoch(
model.step()
else:
scaler.scale(loss).backward()
set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
@ -739,8 +739,19 @@ def run(rank, world_size, args):
replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads
if params.freeze_modules is not None:
for name, p in model.named_parameters():
if name.startswith(params.freeze_modules):
p.requires_grad = False
logging.info(f"Do not update {name}")
for name, module in model.named_modules():
if name.startswith(params.freeze_modules):
module.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual,