Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

Attempt to implement slower learning for downsampled modules

commit 1db509ea31
parent b7be18c2f8
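
The hunks below work together: ScaledAdam learns to accept groups of named parameters, Zipformer tags its downsampled encoder stacks with an lr_scale attribute, and a new helper turns those tags into per-LR parameter groups. A minimal usage sketch follows, assuming ScaledAdam and get_named_parameter_groups_with_lrs are importable from the files this commit touches; the toy model and numeric values are illustrative placeholders, not taken from the commit:

import torch.nn as nn

# A stand-in for a Zipformer-style model; the second block plays the role of a
# "downsampled" stack and gets the new lr_scale attribute (see the Zipformer hunk).
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model[1].lr_scale = 0.5

base_lr = 0.04  # placeholder value
named_param_groups = get_named_parameter_groups_with_lrs(
    model, lr=base_lr, include_names=True
)
# Case 4 of ScaledAdam's docstring: each group carries (name, tensor) pairs
# under the new 'named_params' key, so dominant-parameter logging keeps names.
optimizer = ScaledAdam(named_param_groups, lr=base_lr, clipping_scale=2.0)
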
@@ -200,8 +200,8 @@ class ScaledAdam(BatchedOptimizer):
         # i.e. when parameter names are not given,
         # this flag will be set to False in function _get_names_of_parameters.
         self.show_dominant_parameters = True
-        params, parameters_names = self._get_names_of_parameters(params)
-        super(ScaledAdam, self).__init__(params, defaults)
+        param_groups, parameters_names = self._get_names_of_parameters(params)
+        super(ScaledAdam, self).__init__(param_groups, defaults)
         assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names

@@ -228,9 +228,9 @@ class ScaledAdam(BatchedOptimizer):

           case 4, a list of named_parameter groups with different config, e.g.:
             model_named_param_groups = [
-              {'params': model.encoder.named_parameters(), 'lr': 0.05},
-              {'params': model.decoder.named_parameters(), 'lr': 0.01},
-              {'params': model.joiner.named_parameters(), 'lr': 0.03},
+              {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
+              {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
+              {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
             ]
             optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)

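
For reference, a small self-contained illustration of the 'named_params' format that case 4 now expects: named_parameters() yields (name, tensor) pairs, and each group keeps its own learning rate. The nn.Linear modules here are toy stand-ins for the recipe's encoder/decoder/joiner:

import torch.nn as nn

encoder = nn.Linear(4, 4)   # toy stand-ins, not the recipe's modules
decoder = nn.Linear(4, 2)

model_named_param_groups = [
    {'named_params': list(encoder.named_parameters()), 'lr': 0.05},
    {'named_params': list(decoder.named_parameters()), 'lr': 0.01},
]

for group in model_named_param_groups:
    names = [name for name, _ in group['named_params']]
    print(group['lr'], names)   # e.g. 0.05 ['weight', 'bias']
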
@@ -267,7 +267,8 @@ class ScaledAdam(BatchedOptimizer):
         if len(iterable_or_groups) == 0:
             raise ValueError("optimizer got an empty parameter list")

-        # The first value of returned tuple.
+        # The first value of returned tuple. A list of dicts containing at
+        # least 'params' as a key.
         param_groups = []

         # The second value of returned tuple,
@@ -297,33 +298,21 @@ class ScaledAdam(BatchedOptimizer):
         else:
             # case 2 or case 4
             # the input is groups of parameter or named parameter.
-            for p_or_np_cur_group in iterable_or_groups:
-                param_iterable_cur_group = []
-                param_names_cur_group = []
-                p_or_np_iterable = p_or_np_cur_group["params"]
-                for p_or_np in p_or_np_iterable:
-                    if isinstance(p_or_np, tuple):
-                        # case 4
-                        name, param = p_or_np
-                    else:
-                        # case 2
-                        assert isinstance(p_or_np, torch.Tensor)
-                        param = p_or_np
-                        # Assign a dummy name as a placeholder
-                        name = "foo"
-                        self.show_dominant_parameters = False
-                    param_iterable_cur_group.append(param)
-                    param_names_cur_group.append(name)
-
-                # The original `params` field contains named_parameters.
-                # After the following assignment,
-                # it will be changed to an iterable of parameters,
-                # and other fields, if they exist, keep their original values.
-                # So param_groups can be used to initialize
-                # an underlying torch.Optimizer later.
-                p_or_np_cur_group["params"] = param_iterable_cur_group
-                param_groups.append(p_or_np_cur_group)
-                param_groups_names.append(param_names_cur_group)
+            for cur_group in iterable_or_groups:
+                if "params" in cur_group:
+                    for p in cur_group["params"]:
+                        assert isinstance(p, torch.Tensor)
+                    param_groups.append(cur_group)
+                    param_groups_names.append(["foo"] * len(cur_group["params"]))
+                else:
+                    assert "named_params" in cur_group
+                    p_list = [ x[1] for x in cur_group["named_params"] ]
+                    name_list = [ x[0] for x in cur_group["named_params"] ]
+                    del cur_group["named_params"]
+                    cur_group["params"] = p_list
+                    param_groups.append(cur_group)
+                    param_groups_names.append(name_list)

         return param_groups, param_groups_names

     def __setstate__(self, state):
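
A standalone sketch of the transformation the rewritten branch performs: a 'named_params' group is rewritten in place into a torch-style 'params' group, while the names go into a parallel list; the result can then initialize an ordinary torch.optim optimizer. The toy model and the use of SGD here are only for illustration:

import torch
import torch.nn as nn

model = nn.Linear(3, 3)
groups = [{'named_params': list(model.named_parameters()), 'lr': 0.01}]

param_groups = []
param_groups_names = []
for cur_group in groups:
    if 'params' in cur_group:
        # names were not given; keep placeholder names for logging
        param_groups.append(cur_group)
        param_groups_names.append(['foo'] * len(cur_group['params']))
    else:
        name_list = [name for name, _ in cur_group['named_params']]
        p_list = [p for _, p in cur_group['named_params']]
        del cur_group['named_params']
        cur_group['params'] = p_list
        param_groups.append(cur_group)
        param_groups_names.append(name_list)

print(param_groups_names)                      # [['weight', 'bias']]
opt = torch.optim.SGD(param_groups, lr=0.01)   # plain torch optimizers accept 'params'
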
@@ -207,6 +207,9 @@ class Zipformer(EncoderInterface):
                 downsample=downsampling_factor[i],
                 dropout=dropout,
             )
+            # we are adding a new attribute here.
+            # this will be interpreted by get_named_parameter_groups_with_lrs().
+            encoder.lr_scale = downsampling_factor[i] ** -0.5
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)

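
As the comment says, lr_scale is a plain Python attribute attached after construction; nothing registers it as a parameter or buffer. A toy check, with nn.Linear standing in for one of the downsampled encoder stacks:

import torch.nn as nn

downsampling_factor = 4                          # illustrative value
encoder = nn.Linear(8, 8)                        # stand-in for one encoder stack
encoder.lr_scale = downsampling_factor ** -0.5   # 4 ** -0.5 == 0.5

# The attribute is visible via hasattr, but it is not a Parameter of the module.
assert hasattr(encoder, 'lr_scale')
assert 'lr_scale' not in dict(encoder.named_parameters())
print(encoder.lr_scale)   # 0.5
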
@@ -835,6 +835,69 @@ def measure_gradient_norms(
         norms[name] = val.item()
     return norms

+
+def get_named_parameter_groups_with_lrs(
+        model: nn.Module,
+        lr: float,
+        include_names: bool = False) -> List[dict]:
+    """
+    This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
+    named-parameters; we can, if needed, create a version without the names).
+
+    It provides a way to specify learning-rate scales inside the module, so that if
+    any nn.Module in the hierarchy has a floating-point attribute 'lr_scale', it will
+    scale the LR of any parameters inside that module or its submodules.  Note: you
+    can set module attributes outside the __init__ function, e.g.:
+      >>> a = nn.Linear(10, 10)
+      >>> a.lr_scale = 0.5
+
+    Returns: a list of dicts, of the following form:
+      if include_names == False:
+        [ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
+          { 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
+          ... ]
+      if include_names == True:
+        [ { 'named_params': [ (name1, tensor1), (name2, tensor2), ... ], 'lr': 0.01 },
+          { 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
+          ... ]
+
+    """
+    named_modules = list(model.named_modules())
+
+    # flat_lr_scale just contains the lr_scale explicitly specified
+    # for each prefix of the name, e.g. 'encoder.layers.3'; these need
+    # to be multiplied over all prefixes of the name of any given parameter.
+    flat_lr_scale = defaultdict(lambda: 1.0)
+    names = []
+    for name, m in model.named_modules():
+        names.append(name)
+        if hasattr(m, 'lr_scale'):
+            flat_lr_scale[name] = m.lr_scale
+
+
+    # lr_to_params is a dict from learning rate (floating point) to: if
+    # include_names == True, a list of (name, parameter) for that learning rate;
+    # otherwise a list of parameters for that learning rate.
+    lr_to_params = defaultdict(list)
+
+    for name, parameter in model.named_parameters():
+        split_name = name.split('.')
+        # caution: as a special case, if the name is '', split_name will be [ '' ].
+        prefix = split_name[0]
+        cur_lr = lr * flat_lr_scale[prefix]
+        if prefix != '':
+            cur_lr *= flat_lr_scale['']
+        for part in split_name[1:]:
+            prefix = '.'.join([ prefix, part ])
+            cur_lr *= flat_lr_scale[prefix]
+        lr_to_params[cur_lr].append(
+            (name, parameter) if include_names else parameter
+        )
+
+    if include_names:
+        return [ { 'named_params': pairs, 'lr': lr } for lr, pairs in lr_to_params.items() ]
+    else:
+        return [ { 'params': params, 'lr': lr } for lr, params in lr_to_params.items() ]
+
+
 def optim_step_and_measure_param_change(
     model: nn.Module,
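
To make the prefix logic above concrete, here is a worked, self-contained sketch of how a parameter's effective LR is accumulated: every dotted prefix of its name ('', 'encoders', 'encoders.3', ...) can contribute an lr_scale. The names and scale values below are toy examples, not from the recipe:

from collections import defaultdict

flat_lr_scale = defaultdict(lambda: 1.0)
flat_lr_scale[''] = 1.0            # the root module itself
flat_lr_scale['encoders.3'] = 0.5  # e.g. a downsampled stack

def lr_for(name: str, base_lr: float) -> float:
    # Mirrors the per-parameter loop above: multiply base_lr by the scale of
    # every prefix of the dotted name (and by the root-level scale).
    split_name = name.split('.')
    prefix = split_name[0]
    cur_lr = base_lr * flat_lr_scale[prefix]
    if prefix != '':
        cur_lr *= flat_lr_scale['']
    for part in split_name[1:]:
        prefix = '.'.join([prefix, part])
        cur_lr *= flat_lr_scale[prefix]
    return cur_lr

print(lr_for('encoders.3.self_attn.weight', 0.04))   # 0.02
print(lr_for('decoder.embedding.weight', 0.04))      # 0.04
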