support freezing modules
This commit is contained in:
parent 360f208037
commit cfbc829df3
@@ -88,15 +88,6 @@ from icefall.utils import (
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
-    if isinstance(model, DDP):
-        # get underlying nn.Module
-        model = model.module
-    for module in model.modules():
-        if hasattr(module, "batch_count"):
-            module.batch_count = batch_count
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -226,6 +217,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--freeze-modules",
+        type=str,
+        default=None,
+        help="Which modules to freeze during finetune"
+    )
+
     parser = deepspeed.add_config_arguments(parser)
 
     return parser
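As a quick standalone illustration (not part of the commit; the "encoder" value is just an example), the new option parses as a plain string, and argparse turns the dash into an underscore, which is why the training code later reads params.freeze_modules:

# Standalone sketch of how --freeze-modules parses; not part of the commit.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--freeze-modules",
    type=str,
    default=None,
    help="Which modules to freeze during finetune",
)

# argparse converts the dash to an underscore, so the value is available as
# args.freeze_modules (params.freeze_modules inside the training script).
args = parser.parse_args(["--freeze-modules", "encoder"])
assert args.freeze_modules == "encoder"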
@@ -583,6 +581,9 @@ def train_one_epoch(
       be set to 0.
     """
     model.train()
+    for name, module in model.named_modules():
+        if name.startswith(params.freeze_modules):
+            module.eval()
 
     tot_loss = MetricsTracker()
 
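One detail worth noting: str.startswith raises TypeError when given None, so the eval() loop above assumes --freeze-modules is always set when fine-tuning; startswith also accepts a tuple, which would allow several prefixes at once. A guarded helper along these lines (hypothetical, not part of the commit) makes both cases explicit:

# Hypothetical helper, not part of the commit: prefix matching with a guard
# for the default --freeze-modules=None case.
from typing import Optional, Tuple, Union


def matches_frozen_prefix(
    name: str, freeze_modules: Optional[Union[str, Tuple[str, ...]]]
) -> bool:
    if freeze_modules is None:
        # Nothing is frozen when the option is left at its default.
        return False
    # str.startswith also accepts a tuple, so several prefixes could be
    # frozen at once, e.g. ("encoder", "decoder.token_embedding").
    return name.startswith(freeze_modules)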
@@ -630,7 +631,6 @@ def train_one_epoch(
             model.step()
         else:
             scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
            scheduler.step_batch(params.batch_idx_train)
 
            scaler.step(optimizer)
@@ -739,8 +739,19 @@ def run(rank, world_size, args):
     replace_whisper_encoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
+
+    if params.freeze_modules is not None:
+        for name, p in model.named_parameters():
+            if name.startswith(params.freeze_modules):
+                p.requires_grad = False
+                logging.info(f"Do not update {name}")
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
+
     num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
+    num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
 
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
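The freezing pattern can be exercised in isolation. The toy model below is illustrative only (the module names do not come from the Whisper recipe); it shows the two steps the commit combines: requires_grad=False stops parameter updates, while eval() additionally keeps layers such as Dropout and BatchNorm in inference mode inside the frozen sub-modules.

# Self-contained sketch of the freezing pattern used above, on a toy model.
# Module names and the "encoder" prefix are illustrative, not from the recipe.
import torch.nn as nn

model = nn.ModuleDict(
    {
        "encoder": nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1)),
        "decoder": nn.Linear(4, 4),
    }
)
freeze_modules = "encoder"  # value that --freeze-modules would supply

# Stop gradient updates for the frozen parameters.
for name, p in model.named_parameters():
    if name.startswith(freeze_modules):
        p.requires_grad = False

# Put the frozen sub-modules in inference mode (disables Dropout, stops
# BatchNorm running-statistics updates, etc.).
for name, module in model.named_modules():
    if name.startswith(freeze_modules):
        module.eval()

num_param = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")

Since model.train() at the top of train_one_epoch puts every sub-module back into training mode, the commit re-applies eval() to the frozen prefixes there as well, which is what the earlier train_one_epoch hunk does.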