Use get_parameter_groups_with_lr in train.py; bug fixes

This commit is contained in:
Daniel Povey 2023-01-05 14:02:01 +08:00
parent 1db509ea31
commit 0d7161ebec
3 changed files with 15 additions and 8 deletions

View File

@ -306,8 +306,8 @@ class ScaledAdam(BatchedOptimizer):
param_groups_names.append(["foo"] * len(p_list))
else:
assert "named_params" in cur_group
p_list = [ x[0] for x in cur_group["named_params"] ]
name_list = [ x[1] for x in cur_group["named_params"] ]
name_list = [ x[0] for x in cur_group["named_params"] ]
p_list = [ x[1] for x in cur_group["named_params"] ]
del cur_group["named_params"]
cur_group["params"] = p_list
param_groups.append(cur_group)

View File

@ -83,7 +83,13 @@ from icefall.checkpoint import (
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
get_parameter_groups_with_lrs
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -1075,8 +1081,9 @@ def run(rank, world_size, args):
find_unused_parameters=True)
optimizer = ScaledAdam(
model.named_parameters(),
lr=params.base_lr,
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)

View File

@ -835,7 +835,7 @@ def measure_gradient_norms(
norms[name] = val.item()
return norms
def get_named_parameter_groups_with_lrs(
def get_parameter_groups_with_lrs(
model: nn.Module,
lr: float,
include_names: bool = False) -> List[dict]:
@ -861,7 +861,7 @@ def get_named_parameter_groups_with_lrs(
... ]
"""
named_modules = list(model.named_modules)
named_modules = list(model.named_modules())
# flat_lr_scale just contains the lr_scale explicitly specified
# for each prefix of the name, e.g. 'encoder.layers.3', these need
@ -879,7 +879,7 @@ def get_named_parameter_groups_with_lrs(
# otherwise a list of parameters for that learning rate.
lr_to_params = defaultdict(list)
for name, parameter in module.named_parameters():
for name, parameter in model.named_parameters():
split_name = name.split('.')
# caution: as a special case, if the name is '', split_name will be [ '' ].
prefix = split_name[0]