Use get_parameter_groups_with_lrs in train.py; bug fixes
commit 0d7161ebec
parent 1db509ea31

@@ -306,8 +306,8 @@ class ScaledAdam(BatchedOptimizer):
                 param_groups_names.append(["foo"] * len(p_list))
             else:
                 assert "named_params" in cur_group
-                p_list = [ x[0] for x in cur_group["named_params"] ]
-                name_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [ x[0] for x in cur_group["named_params"] ]
+                p_list = [ x[1] for x in cur_group["named_params"] ]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
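For context on the first fix: each entry of a "named_params" group is a (name, parameter) tuple, in the same order produced by nn.Module.named_parameters(), so x[0] is the name string and x[1] is the tensor; the two list comprehensions were simply swapped. A minimal sketch with a stand-in module (the Linear layer is only for illustration, not part of this commit):

    import torch.nn as nn

    toy = nn.Linear(4, 2)                    # stand-in module for illustration
    named = list(toy.named_parameters())     # [('weight', Parameter), ('bias', Parameter)]
    name_list = [x[0] for x in named]        # ['weight', 'bias']
    p_list = [x[1] for x in named]           # the Parameter tensors themselves
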
@@ -83,7 +83,13 @@ from icefall.checkpoint import (
 from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    get_parameter_groups_with_lrs
+)
 
 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -1075,8 +1081,9 @@ def run(rank, world_size, args):
             find_unused_parameters=True)
 
     optimizer = ScaledAdam(
-        model.named_parameters(),
-        lr=params.base_lr,
+        get_parameter_groups_with_lrs(
+            model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr, # should have no effect
         clipping_scale=2.0,
     )
 
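With the import and optimizer hunks above, train.py now lets get_parameter_groups_with_lrs build the parameter groups, each carrying its own lr (and the parameter names when include_names=True), so the top-level lr passed to ScaledAdam is expected to be overridden per group, which is what the "should have no effect" comment records. A sketch of the call pattern with a stand-in model and an illustrative base_lr; the ScaledAdam import path is assumed, as it is not shown in this diff:

    import torch.nn as nn
    from icefall.utils import get_parameter_groups_with_lrs
    from optim import ScaledAdam                    # import path assumed

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))  # stand-in model
    base_lr = 0.045                                 # illustrative value only

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=base_lr, include_names=True),
        lr=base_lr,           # per-group lr values take precedence
        clipping_scale=2.0,
    )
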
@@ -835,7 +835,7 @@ def measure_gradient_norms(
         norms[name] = val.item()
     return norms
 
-def get_named_parameter_groups_with_lrs(
+def get_parameter_groups_with_lrs(
     model: nn.Module,
     lr: float,
     include_names: bool = False) -> List[dict]:
@@ -861,7 +861,7 @@ def get_named_parameter_groups_with_lrs(
     ... ]
 
     """
-    named_modules = list(model.named_modules)
+    named_modules = list(model.named_modules())
 
     # flat_lr_scale just contains the lr_scale explicitly specified
     # for each prefix of the name, e.g. 'encoder.layers.3', these need
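The named_modules change adds the missing call parentheses: model.named_modules is a bound method, so list(model.named_modules) raises "TypeError: 'method' object is not iterable", whereas list(model.named_modules()) yields the intended (name, module) pairs. A quick check with a stand-in model:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4))      # stand-in model for illustration
    # list(model.named_modules)                 # TypeError: 'method' object is not iterable
    pairs = list(model.named_modules())         # [('', Sequential(...)), ('0', Linear(...))]
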
@@ -879,7 +879,7 @@ def get_named_parameter_groups_with_lrs(
     # otherwise a list of parameters for that learning rate.
     lr_to_params = defaultdict(list)
 
-    for name, parameter in module.named_parameters():
+    for name, parameter in model.named_parameters():
         split_name = name.split('.')
         # caution: as a special case, if the name is '', split_name will be [ '' ].
         prefix = split_name[0]
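The last hunk makes the loop walk the whole model's parameters rather than a `module` variable presumably left bound by an earlier loop over named_modules. For the prefix logic in the surrounding lines, the split behaves as the comment warns, including the empty-name special case it calls out:

    name = 'encoder.layers.3.weight'
    prefix = name.split('.')[0]     # 'encoder'

    root = ''
    root.split('.')                 # [''] -- so the prefix for an empty name is ''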