diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index abd7fde8e..776c88a0e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -306,8 +306,8 @@ class ScaledAdam(BatchedOptimizer):
                 param_groups_names.append(["foo"] * len(p_list))
             else:
                 assert "named_params" in cur_group
-                p_list = [ x[0] for x in cur_group["named_params"] ]
-                name_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [ x[0] for x in cur_group["named_params"] ]
+                p_list = [ x[1] for x in cur_group["named_params"] ]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 2454a87a2..b70591192 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -83,7 +83,13 @@ from icefall.checkpoint import (
 from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    setup_logger,
+    str2bool,
+    get_parameter_groups_with_lrs
+)
 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
@@ -1075,8 +1081,9 @@ def run(rank, world_size, args):
                     find_unused_parameters=True)
 
     optimizer = ScaledAdam(
-        model.named_parameters(),
-        lr=params.base_lr,
+        get_parameter_groups_with_lrs(
+            model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
         clipping_scale=2.0,
     )
 
diff --git a/icefall/utils.py b/icefall/utils.py
index 157bd90c4..92bf984fb 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -835,7 +835,7 @@ def measure_gradient_norms(
             norms[name] = val.item()
     return norms
 
-def get_named_parameter_groups_with_lrs(
+def get_parameter_groups_with_lrs(
         model: nn.Module,
         lr: float,
         include_names: bool = False) -> List[dict]:
@@ -861,7 +861,7 @@ def get_named_parameter_groups_with_lrs(
            ...
        ]
     """
-    named_modules = list(model.named_modules)
+    named_modules = list(model.named_modules())
 
     # flat_lr_scale just contains the lr_scale explicitly specified
     # for each prefix of the name, e.g. 'encoder.layers.3', these need
@@ -879,7 +879,7 @@
     # otherwise a list of parameters for that learning rate.
     lr_to_params = defaultdict(list)
 
-    for name, parameter in module.named_parameters():
+    for name, parameter in model.named_parameters():
         split_name = name.split('.')
         # caution: as a special case, if the name is '', split_name will be [ '' ].
         prefix = split_name[0]
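
A minimal usage sketch (not part of the patch) of how the renamed helper is wired into ScaledAdam after this change. The toy model, the base lr of 0.045, and the decoder lr_scale value are hypothetical; the sketch assumes, as the docstring fragment above describes, that a float lr_scale attribute on a module scales the learning rate of every parameter under that module.

    import torch.nn as nn

    from icefall.utils import get_parameter_groups_with_lrs
    # ScaledAdam is defined in egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
    from optim import ScaledAdam


    class ToyModel(nn.Module):
        """Hypothetical two-submodule model, for illustration only."""

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)
            self.decoder = nn.Linear(4, 4)
            # Assumption: decoder parameters get 0.5 * base lr, so
            # get_parameter_groups_with_lrs() returns two groups here,
            # one with lr=0.045 and one with lr=0.0225.
            self.decoder.lr_scale = 0.5


    model = ToyModel()

    # include_names=True makes each group carry 'named_params' as
    # (name, param) pairs -- the ordering that the optim.py hunk above
    # now unpacks correctly into name_list and p_list.
    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=0.045, include_names=True),
        lr=0.045,  # should have no effect; the per-group lrs take precedence
        clipping_scale=2.0,
    )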