mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
support freezing modules
This commit is contained in:
parent
360f208037
commit
cfbc829df3
@ -88,15 +88,6 @@ from icefall.utils import (
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||
if isinstance(model, DDP):
|
||||
# get underlying nn.Module
|
||||
model = model.module
|
||||
for module in model.modules():
|
||||
if hasattr(module, "batch_count"):
|
||||
module.batch_count = batch_count
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -226,6 +217,13 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--freeze-modules",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Which modules to freeze during finetune"
|
||||
)
|
||||
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -583,6 +581,9 @@ def train_one_epoch(
|
||||
be set to 0.
|
||||
"""
|
||||
model.train()
|
||||
for name, module in model.named_modules():
|
||||
if name.startswith(params.freeze_modules):
|
||||
module.eval()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
@ -630,7 +631,6 @@ def train_one_epoch(
|
||||
model.step()
|
||||
else:
|
||||
scaler.scale(loss).backward()
|
||||
set_batch_count(model, params.batch_idx_train)
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
|
||||
scaler.step(optimizer)
|
||||
@ -739,8 +739,19 @@ def run(rank, world_size, args):
|
||||
replace_whisper_encoder_forward()
|
||||
model = whisper.load_model(params.model_name, "cpu")
|
||||
del model.alignment_heads
|
||||
|
||||
if params.freeze_modules is not None:
|
||||
for name, p in model.named_parameters():
|
||||
if name.startswith(params.freeze_modules):
|
||||
p.requires_grad = False
|
||||
logging.info(f"Do not update {name}")
|
||||
for name, module in model.named_modules():
|
||||
if name.startswith(params.freeze_modules):
|
||||
module.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
|
||||
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(
|
||||
model.is_multilingual,
|
||||
|
Loading…
x
Reference in New Issue
Block a user