support freezing modules

marcoyang 2024-03-28 18:16:33 +08:00
parent 360f208037
commit cfbc829df3


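The diff below adds a --freeze-modules option to the Whisper fine-tuning script: parameters whose names start with the given prefix get requires_grad = False, and the matching modules are kept in eval() mode during training. A minimal, self-contained sketch of that pattern (the module names and the "encoder" prefix are illustrative, not taken from the script):

    import torch.nn as nn

    model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
    prefix = "encoder"  # stands in for params.freeze_modules

    for name, p in model.named_parameters():
        if name.startswith(prefix):
            p.requires_grad = False   # excluded from gradient updates
    for name, module in model.named_modules():
        if name.startswith(prefix):
            module.eval()             # dropout/batchnorm behave as in inference

    num_param = sum(p.numel() for p in model.parameters())
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"total={num_param}, trainable={num_trainable}")  # total=40, trainable=20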
@@ -88,15 +88,6 @@ from icefall.utils import (
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
-    if isinstance(model, DDP):
-        # get underlying nn.Module
-        model = model.module
-    for module in model.modules():
-        if hasattr(module, "batch_count"):
-            module.batch_count = batch_count
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -226,6 +217,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--freeze-modules",
+        type=str,
+        default=None,
+        help="Which modules to freeze during fine-tuning.",
+    )
     parser = deepspeed.add_config_arguments(parser)
     return parser
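Since --freeze-modules is a plain string, matching is done with str.startswith on the dotted parameter and module names; a short illustration (the names below are made up for the example):

    freeze_modules = "encoder"  # value passed via --freeze-modules
    print("encoder.blocks.0.attn.query.weight".startswith(freeze_modules))  # True  -> frozen
    print("decoder.blocks.0.attn.query.weight".startswith(freeze_modules))  # False -> still trained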
@@ -583,6 +581,10 @@ def train_one_epoch(
         be set to 0.
     """
     model.train()
+    if params.freeze_modules is not None:
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
     tot_loss = MetricsTracker()
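The eval() calls are repeated at the top of train_one_epoch because model.train() switches every submodule back to training mode; a small standalone sketch of that behaviour:

    import torch.nn as nn

    m = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
    m[1].eval()           # freeze-style eval() on one submodule
    m.train()             # what train_one_epoch does first
    print(m[1].training)  # True -> eval() has to be re-applied after train()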
@@ -630,7 +631,6 @@ def train_one_epoch(
                 model.step()
             else:
                 scaler.scale(loss).backward()
-                set_batch_count(model, params.batch_idx_train)
                 scheduler.step_batch(params.batch_idx_train)
                 scaler.step(optimizer)
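Nothing else in this backward/optimizer path needs to change for freezing: parameters with requires_grad = False simply receive no gradients, so optimizer.step() leaves them untouched. A standalone sketch:

    import torch
    import torch.nn as nn

    layer = nn.Linear(2, 2)
    layer.weight.requires_grad = False         # "frozen" weight
    layer(torch.randn(1, 2)).sum().backward()
    print(layer.weight.grad)  # None      -> optimizer.step() cannot change it
    print(layer.bias.grad)    # a tensor  -> the bias is still trained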
@@ -739,8 +739,19 @@ def run(rank, world_size, args):
     replace_whisper_encoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
+    if params.freeze_modules is not None:
+        for name, p in model.named_parameters():
+            if name.startswith(params.freeze_modules):
+                p.requires_grad = False
+                logging.info(f"Do not update {name}")
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
     num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
+    num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,