diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 203861fc8..abd7fde8e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -200,8 +200,8 @@ class ScaledAdam(BatchedOptimizer):
         # i.e when parameter names are not given,
         # this flag will be set to False in funciton _get_names_of_parameters.
         self.show_dominant_parameters = True
-        params, parameters_names = self._get_names_of_parameters(params)
-        super(ScaledAdam, self).__init__(params, defaults)
+        param_groups, parameters_names = self._get_names_of_parameters(params)
+        super(ScaledAdam, self).__init__(param_groups, defaults)
 
         assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
@@ -228,9 +228,9 @@ class ScaledAdam(BatchedOptimizer):
 
           case 4, a list of named_parameter groups with different config, e.g.:
             model_named_param_groups = [
-                {'params': model.encoder.named_parameters(), 'lr': 0.05},
-                {'params': model.decoder.named_parameters(), 'lr': 0.01},
-                {'params': model.joiner.named_parameters(), 'lr': 0.03},
+                {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
+                {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
+                {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
             ]
             optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
@@ -267,7 +267,8 @@ class ScaledAdam(BatchedOptimizer):
         if len(iterable_or_groups) == 0:
             raise ValueError("optimizer got an empty parameter list")
 
-        # The first value of returned tuple.
+        # The first value of the returned tuple: a list of dicts, each
+        # containing at least 'params' as a key.
         param_groups = []
 
         # The second value of returned tuple,
@@ -297,33 +298,21 @@ class ScaledAdam(BatchedOptimizer):
         else:
             # case 2 or case 4
             # the input is groups of parameter or named parameter.
-            for p_or_np_cur_group in iterable_or_groups:
-                param_iterable_cur_group = []
-                param_names_cur_group = []
-                p_or_np_iterable = p_or_np_cur_group["params"]
-                for p_or_np in p_or_np_iterable:
-                    if isinstance(p_or_np, tuple):
-                        # case 4
-                        name, param = p_or_np
-                    else:
-                        # case 2
-                        assert isinstance(p_or_np, torch.Tensor)
-                        param = p_or_np
-                        # Assign a dummy name as a placeholder
-                        name = "foo"
-                        self.show_dominant_parameters = False
-                    param_iterable_cur_group.append(param)
-                    param_names_cur_group.append(name)
+            for cur_group in iterable_or_groups:
+                if "params" in cur_group:
+                    # case 2: plain parameters; assign dummy names as placeholders.
+                    # Materialize first, since 'params' may be a generator.
+                    p_list = list(cur_group["params"])
+                    for p in p_list:
+                        assert isinstance(p, torch.Tensor)
+                    cur_group["params"] = p_list
+                    self.show_dominant_parameters = False
+                    param_groups.append(cur_group)
+                    param_groups_names.append(["foo"] * len(p_list))
+                else:
+                    # case 4: (name, parameter) pairs, e.g. from named_parameters().
+                    assert "named_params" in cur_group
+                    named_params = list(cur_group["named_params"])
+                    name_list = [x[0] for x in named_params]
+                    p_list = [x[1] for x in named_params]
+                    del cur_group["named_params"]
+                    cur_group["params"] = p_list
+                    param_groups.append(cur_group)
+                    param_groups_names.append(name_list)
 
-                # The original `params` filed contains named_parameters.
-                # After following assignment,
-                # it will be changed to an iterable of parameter,
-                # and other fileds, if exist, are still original values.
-                # So param_groups could be used to initialize
-                # an underlying torch.Optimizer later.
-                p_or_np_cur_group["params"] = param_iterable_cur_group
-                param_groups.append(p_or_np_cur_group)
-                param_groups_names.append(param_names_cur_group)
 
         return param_groups, param_groups_names
 
     def __setstate__(self, state):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f7c9c487b..a33f768fb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -207,6 +207,9 @@ class Zipformer(EncoderInterface):
                     downsample=downsampling_factor[i],
                     dropout=dropout,
                 )
+            # We are adding a new attribute here; it will be interpreted
+            # by get_named_parameter_groups_with_lrs().
+            encoder.lr_scale = downsampling_factor[i] ** -0.5
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)
diff --git a/icefall/utils.py b/icefall/utils.py
index 09523fdf2..157bd90c4 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -835,6 +835,69 @@ def measure_gradient_norms(
         norms[name] = val.item()
     return norms
 
+def get_named_parameter_groups_with_lrs(
+    model: nn.Module,
+    lr: float,
+    include_names: bool = False) -> List[dict]:
+    """
+    This is for use with the ScaledAdam optimizer (the more recent versions that accept
+    lists of named parameters; we can, if needed, create a version without the names).
+
+    It provides a way to specify learning-rate scales inside the model, so that if
+    any nn.Module in the hierarchy has a floating-point attribute 'lr_scale', it will
+    scale the LR of any parameters inside that module or its submodules.  Note: you
+    can set such module attributes outside the __init__ function, e.g.:
+    >>> a = nn.Linear(10, 10)
+    >>> a.lr_scale = 0.5
+
+    Returns: a list of dicts, of the following form:
+      if include_names == False:
+        [ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
+          { 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
+          ... ]
+      if include_names == True:
+        [ { 'named_params': [ (name1, tensor1), (name2, tensor2), ... ], 'lr': 0.01 },
+          { 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
+          ... ]
+    """
+    # flat_lr_scale just contains the lr_scale explicitly specified
+    # for each prefix of the name, e.g. 'encoder.layers.3'; these need
+    # to be multiplied for all prefixes of the name of any given parameter.
+    flat_lr_scale = defaultdict(lambda: 1.0)
+    names = []
+    for name, m in model.named_modules():
+        names.append(name)
+        if hasattr(m, 'lr_scale'):
+            flat_lr_scale[name] = m.lr_scale
+
+    # lr_to_params is a dict from learning rate (floating point) to: if
+    # include_names == True, a list of (name, parameter) pairs for that learning rate;
+    # otherwise a list of parameters for that learning rate.
+    lr_to_params = defaultdict(list)
+
+    for name, parameter in model.named_parameters():
+        split_name = name.split('.')
+        # caution: as a special case, if the name is '', split_name will be [ '' ].
+        prefix = split_name[0]
+        cur_lr = lr * flat_lr_scale[prefix]
+        if prefix != '':
+            cur_lr *= flat_lr_scale['']
+        for part in split_name[1:]:
+            prefix = '.'.join([prefix, part])
+            cur_lr *= flat_lr_scale[prefix]
+        lr_to_params[cur_lr].append(
+            (name, parameter) if include_names else parameter
+        )
+
+    if include_names:
+        return [{'named_params': pairs, 'lr': lr} for lr, pairs in lr_to_params.items()]
+    else:
+        return [{'params': params, 'lr': lr} for lr, params in lr_to_params.items()]
+
 def optim_step_and_measure_param_change(
     model: nn.Module,
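
For reference, here is a minimal usage sketch (not part of the patch) showing how the pieces above are intended to fit together once it is applied. The toy model, the base LR of 0.045 and the clipping_scale value are illustrative only, and the ScaledAdam import assumes you are running from the pruned_transducer_stateless7 recipe directory.

    import torch.nn as nn

    from icefall.utils import get_named_parameter_groups_with_lrs


    class TinyModel(nn.Module):
        """A stand-in for Zipformer: any submodule may carry an 'lr_scale' attribute."""

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(16, 16)
            self.decoder = nn.Linear(16, 16)
            # Scale down the LR of everything inside self.decoder by 0.5,
            # analogous to encoder.lr_scale = downsampling_factor[i] ** -0.5 above.
            self.decoder.lr_scale = 0.5


    model = TinyModel()
    base_lr = 0.045  # illustrative value

    # One group per distinct effective LR; with include_names=True each group
    # uses the new 'named_params' key that the patched ScaledAdam understands.
    groups = get_named_parameter_groups_with_lrs(model, lr=base_lr, include_names=True)
    for g in groups:
        print(g["lr"], [name for name, _ in g["named_params"]])
    # -> 0.045  ['encoder.weight', 'encoder.bias']
    # -> 0.0225 ['decoder.weight', 'decoder.bias']

    # In a recipe, the groups would then be handed to the optimizer, e.g.:
    #   from optim import ScaledAdam  # pruned_transducer_stateless7/optim.py
    #   optimizer = ScaledAdam(groups, lr=base_lr, clipping_scale=2.0)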