diff --git a/egs/librispeech/ASR/whisper/train.py b/egs/librispeech/ASR/whisper/train.py
index bd6b27d99..db6f2e182 100755
--- a/egs/librispeech/ASR/whisper/train.py
+++ b/egs/librispeech/ASR/whisper/train.py
@@ -88,15 +88,6 @@ from icefall.utils import (
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
-    if isinstance(model, DDP):
-        # get underlying nn.Module
-        model = model.module
-    for module in model.modules():
-        if hasattr(module, "batch_count"):
-            module.batch_count = batch_count
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -226,6 +217,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--freeze-modules",
+        type=str,
+        default=None,
+        help="Name prefix of the modules to freeze during fine-tuning, e.g. 'encoder'.",
+    )
+
     parser = deepspeed.add_config_arguments(parser)
 
     return parser
@@ -583,6 +581,11 @@ def train_one_epoch(
       be set to 0.
     """
     model.train()
+    # Keep the frozen modules in eval mode (model.train() switched them back).
+    if params.freeze_modules is not None:
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
 
     tot_loss = MetricsTracker()
 
@@ -630,7 +633,6 @@
                 model.step()
             else:
                 scaler.scale(loss).backward()
-                set_batch_count(model, params.batch_idx_train)
                 scheduler.step_batch(params.batch_idx_train)
 
                 scaler.step(optimizer)
@@ -739,8 +741,19 @@ def run(rank, world_size, args):
     replace_whisper_encoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
+
+    if params.freeze_modules is not None:
+        for name, p in model.named_parameters():
+            if name.startswith(params.freeze_modules):
+                p.requires_grad = False
+                logging.info(f"Do not update {name}")
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
+
     num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
+    num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
 
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
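
For reference, a minimal self-contained sketch of the prefix-based freezing pattern the patch adds: parameters and modules are matched by name prefix, gradients are disabled, and the matched modules are kept in eval mode. `TinyModel` and the `encoder` prefix below are illustrative stand-ins rather than names taken from train.py (though the Whisper model does expose `encoder` and `decoder` submodules).

```python
# Standalone sketch (not part of the patch) of the freezing logic above.
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 2)

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = TinyModel()
freeze_modules = "encoder"  # value that --freeze-modules would carry

# Disable gradients for parameters whose names start with the prefix.
for name, p in model.named_parameters():
    if name.startswith(freeze_modules):
        p.requires_grad = False

# Put the matching modules into eval mode as well.
for name, module in model.named_modules():
    if name.startswith(freeze_modules):
        module.eval()

num_param = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"total={num_param}, trainable={num_trainable}")  # total=30, trainable=10
```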