Mirror of https://github.com/k2-fsa/icefall.git
Use get_parameter_groups_with_lr in train.py; bug fixes

parent 1db509ea31
commit 0d7161ebec
@@ -306,8 +306,8 @@ class ScaledAdam(BatchedOptimizer):
                 param_groups_names.append(["foo"] * len(p_list))
             else:
                 assert "named_params" in cur_group
-                p_list = [ x[0] for x in cur_group["named_params"] ]
-                name_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [ x[0] for x in cur_group["named_params"] ]
+                p_list = [ x[1] for x in cur_group["named_params"] ]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
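Note on the fix above: Module.named_parameters() yields (name, parameter) tuples, so index 0 is the name and index 1 is the tensor; the old code unzipped them the wrong way round. A minimal plain-PyTorch sketch (not icefall code) of the corrected unzipping:

import torch.nn as nn

model = nn.Linear(4, 2)
named_params = list(model.named_parameters())  # [(name, parameter), ...]

# Corrected order: names at index 0, parameter tensors at index 1.
name_list = [x[0] for x in named_params]  # ["weight", "bias"]
p_list = [x[1] for x in named_params]     # the actual nn.Parameter tensors

assert all(isinstance(n, str) for n in name_list)
assert all(isinstance(p, nn.Parameter) for p in p_list)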
@@ -83,7 +83,13 @@ from icefall.checkpoint import (
 from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    get_parameter_groups_with_lrs
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -1075,8 +1081,9 @@ def run(rank, world_size, args):
             find_unused_parameters=True)

     optimizer = ScaledAdam(
-        model.named_parameters(),
-        lr=params.base_lr,
+        get_parameter_groups_with_lrs(
+            model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr, # should have no effect
         clipping_scale=2.0,
     )

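For readers following the call-site change: each group returned by the helper carries its own lr, which is why the top-level lr passed to ScaledAdam "should have no effect". Below is a self-contained toy stand-in; toy_parameter_groups_with_lrs is hypothetical, not the icefall helper (which, per the hunks further down, also walks named_modules() and applies per-module lr_scale factors), and is only meant to show the group format implied by this diff:

from typing import List

import torch.nn as nn


def toy_parameter_groups_with_lrs(
    model: nn.Module, lr: float, include_names: bool = False
) -> List[dict]:
    # Hypothetical stand-in: the real icefall helper splits parameters into
    # several groups with per-sub-module learning rates; this toy puts
    # everything into a single group at the base lr, just to show the shape.
    if include_names:
        return [{"named_params": list(model.named_parameters()), "lr": lr}]
    return [{"params": [p for _, p in model.named_parameters()], "lr": lr}]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
groups = toy_parameter_groups_with_lrs(model, lr=0.045, include_names=True)
print(groups[0]["lr"])                                  # 0.045
print([name for name, _ in groups[0]["named_params"]])  # ['0.weight', ...]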
@@ -835,7 +835,7 @@ def measure_gradient_norms(
         norms[name] = val.item()
     return norms

-def get_named_parameter_groups_with_lrs(
+def get_parameter_groups_with_lrs(
     model: nn.Module,
     lr: float,
     include_names: bool = False) -> List[dict]:
@@ -861,7 +861,7 @@ def get_named_parameter_groups_with_lrs(
     ... ]

     """
-    named_modules = list(model.named_modules)
+    named_modules = list(model.named_modules())

     # flat_lr_scale just contains the lr_scale explicitly specified
     # for each prefix of the name, e.g. 'encoder.layers.3', these need
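The one-character fix above matters because model.named_modules without parentheses is a bound method, not an iterable; list() over it raises a TypeError. A quick plain-PyTorch demonstration:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))

try:
    list(model.named_modules)  # old code: missing parentheses
except TypeError as err:
    print("old code fails:", err)  # 'method' object is not iterable

named_modules = list(model.named_modules())  # fixed: call the method
print([name for name, _ in named_modules])   # ['', '0']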
@@ -879,7 +879,7 @@ def get_named_parameter_groups_with_lrs(
     # otherwise a list of parameters for that learning rate.
     lr_to_params = defaultdict(list)

-    for name, parameter in module.named_parameters():
+    for name, parameter in model.named_parameters():
         split_name = name.split('.')
         # caution: as a special case, if the name is '', split_name will be [ '' ].
         prefix = split_name[0]
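The final fix swaps a stale `module` variable for `model`, so the loop sees every parameter's fully qualified dotted name; the prefix computation that follows appears to feed the lr_scale lookup described in the comments earlier in this function. A small illustration of those names and prefixes:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

for name, parameter in model.named_parameters():
    split_name = name.split('.')  # e.g. '0.weight' -> ['0', 'weight']
    prefix = split_name[0]        # top-level sub-module index or name
    print(f"{name!r:>12} -> prefix {prefix!r}")
# '0.weight' -> prefix '0'
# '0.bias'   -> prefix '0'
# '1.weight' -> prefix '1'
# '1.bias'   -> prefix '1'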