mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
Merge b011448a287784fa14d44d9c6a5aeb8dc9705c27 into 34e40a86b33102576b3442329421178a487e3ea3
This commit is contained in:
commit
3c3005bde0
@ -18,7 +18,7 @@ import contextlib
|
||||
import logging
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from lhotse.utils import fix_random_seed
|
||||
@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
Args:
|
||||
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
||||
Unlike common optimizers, which accept model.parameters() or groups of parameters(),
|
||||
this optimizer could accept model.named_parameters() or groups of named_parameters().
|
||||
See comments of function _get_names_of_parameters for its 4 possible cases.
|
||||
lr: The learning rate. We will typically use a learning rate schedule that starts
|
||||
at 0.03 and decreases over time, i.e. much higher than other common
|
||||
optimizers.
|
||||
@ -178,15 +181,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True,
|
||||
):
|
||||
|
||||
assert parameters_names is not None, (
|
||||
"Please prepare parameters_names,"
|
||||
"which is a List[List[str]]. Each List[str] is for a group"
|
||||
"and each str is for a parameter"
|
||||
)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
@ -200,10 +196,135 @@ class ScaledAdam(BatchedOptimizer):
|
||||
clipping_update_period=clipping_update_period,
|
||||
)
|
||||
|
||||
# If params only contains parameters or group of parameters,
|
||||
# i.e when parameter names are not given,
|
||||
# this flag will be set to False in funciton _get_names_of_parameters.
|
||||
self.show_dominant_parameters = True
|
||||
params, parameters_names = self._get_names_of_parameters(params)
|
||||
super(ScaledAdam, self).__init__(params, defaults)
|
||||
assert len(self.param_groups) == len(parameters_names)
|
||||
self.parameters_names = parameters_names
|
||||
self.show_dominant_parameters = show_dominant_parameters
|
||||
|
||||
def _get_names_of_parameters(
|
||||
self, params_or_named_params
|
||||
) -> Tuple[List[Dict], List[List[str]]]:
|
||||
"""
|
||||
Args:
|
||||
params_or_named_params: according to the way ScaledAdam is initialized in train.py,
|
||||
this argument could be one of following 4 cases,
|
||||
case 1, a generator of parameter, e.g.:
|
||||
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
|
||||
|
||||
case 2, a list of parameter groups with different config, e.g.:
|
||||
model_param_groups = [
|
||||
{'params': model.encoder.parameters(), 'lr': 0.05},
|
||||
{'params': model.decoder.parameters(), 'lr': 0.01},
|
||||
{'params': model.joiner.parameters(), 'lr': 0.03},
|
||||
]
|
||||
optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
|
||||
|
||||
case 3, a generator of named_parameter, e.g.:
|
||||
optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
|
||||
|
||||
case 4, a list of named_parameter groups with different config, e.g.:
|
||||
model_named_param_groups = [
|
||||
{'params': model.encoder.named_parameters(), 'lr': 0.05},
|
||||
{'params': model.decoder.named_parameters(), 'lr': 0.01},
|
||||
{'params': model.joiner.named_parameters(), 'lr': 0.03},
|
||||
]
|
||||
optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
|
||||
|
||||
For case 1 and case 2, input params is used to initialize the underlying torch.optimizer.
|
||||
For case 3 and case 4, firstly, names and params are extracted from input named_params,
|
||||
then, these extracted params are used to initialize the underlying torch.optimizer,
|
||||
and these extracted names are mainly used by function
|
||||
`_show_gradient_dominating_parameter`
|
||||
|
||||
Returns:
|
||||
Returns a tuple containing 2 elements:
|
||||
- `param_groups` with type List[Dict], each Dict element is a parameter group.
|
||||
An example of `param_groups` could be:
|
||||
[
|
||||
{'params': `one iterable of Parameter`, 'lr': 0.05},
|
||||
{'params': `another iterable of Parameter`, 'lr': 0.08},
|
||||
{'params': `a third iterable of Parameter`, 'lr': 0.1},
|
||||
]
|
||||
- `param_gruops_names` with type List[List[str]],
|
||||
each `List[str]` is for a group['params'] in param_groups,
|
||||
and each `str` is the name of a parameter.
|
||||
A dummy name "foo" is related to each parameter,
|
||||
if input are params without names, i.e. case 1 or case 2.
|
||||
"""
|
||||
# variable naming convention in this function:
|
||||
# p is short for param.
|
||||
# np is short for named_param.
|
||||
# p_or_np is short for param_or_named_param.
|
||||
# cur is short for current.
|
||||
# group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
|
||||
# groups is a List[group]
|
||||
|
||||
iterable_or_groups = list(params_or_named_params)
|
||||
if len(iterable_or_groups) == 0:
|
||||
raise ValueError("optimizer got an empty parameter list")
|
||||
|
||||
# The first value of returned tuple.
|
||||
param_groups = []
|
||||
|
||||
# The second value of returned tuple,
|
||||
# a List[List[str]], each sub-List is for a group.
|
||||
param_groups_names = []
|
||||
|
||||
if not isinstance(iterable_or_groups[0], dict):
|
||||
# case 1 or case 3,
|
||||
# the input is an iterable of parameter or named parameter.
|
||||
param_iterable_cur_group = []
|
||||
param_names_cur_group = []
|
||||
for p_or_np in iterable_or_groups:
|
||||
if isinstance(p_or_np, tuple):
|
||||
# case 3
|
||||
name, param = p_or_np
|
||||
else:
|
||||
# case 1
|
||||
assert isinstance(p_or_np, torch.Tensor)
|
||||
param = p_or_np
|
||||
# Assign a dummy name as a placeholder
|
||||
name = "foo"
|
||||
self.show_dominant_parameters = False
|
||||
param_iterable_cur_group.append(param)
|
||||
param_names_cur_group.append(name)
|
||||
param_groups.append({"params": param_iterable_cur_group})
|
||||
param_groups_names.append(param_names_cur_group)
|
||||
else:
|
||||
# case 2 or case 4
|
||||
# the input is groups of parameter or named parameter.
|
||||
for p_or_np_cur_group in iterable_or_groups:
|
||||
param_iterable_cur_group = []
|
||||
param_names_cur_group = []
|
||||
p_or_np_iterable = p_or_np_cur_group["params"]
|
||||
for p_or_np in p_or_np_iterable:
|
||||
if isinstance(p_or_np, tuple):
|
||||
# case 4
|
||||
name, param = p_or_np
|
||||
else:
|
||||
# case 2
|
||||
assert isinstance(p_or_np, torch.Tensor)
|
||||
param = p_or_np
|
||||
# Assign a dummy name as a placeholder
|
||||
name = "foo"
|
||||
self.show_dominant_parameters = False
|
||||
param_iterable_cur_group.append(param)
|
||||
param_names_cur_group.append(name)
|
||||
|
||||
# The original `params` filed contains named_parameters.
|
||||
# After following assignment,
|
||||
# it will be changed to an iterable of parameter,
|
||||
# and other fileds, if exist, are still original values.
|
||||
# So param_groups could be used to initialize
|
||||
# an underlying torch.Optimizer later.
|
||||
p_or_np_cur_group["params"] = param_iterable_cur_group
|
||||
param_groups.append(p_or_np_cur_group)
|
||||
param_groups_names.append(param_names_cur_group)
|
||||
return param_groups, param_groups_names
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(ScaledAdam, self).__setstate__(state)
|
||||
@ -397,7 +518,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
|
||||
):
|
||||
"""
|
||||
Show information of parameter wihch dominanting tot_sumsq.
|
||||
Show information of parameter wihch dominating tot_sumsq.
|
||||
|
||||
Args:
|
||||
tuples: a list of tuples of (param, state, param_names)
|
||||
@ -415,7 +536,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
batch_grad = p.grad
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
batch_sumsq_orig = batch_grad**2
|
||||
# Dummpy values used by following `zip` statement.
|
||||
# Dummy values used by following `zip` statement.
|
||||
batch_rms_orig = torch.ones(p.shape[0])
|
||||
else:
|
||||
batch_rms_orig = state["param_rms"]
|
||||
@ -448,11 +569,11 @@ class ScaledAdam(BatchedOptimizer):
|
||||
dominant_grad,
|
||||
) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(
|
||||
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||
f"Parameter dominating tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
f"={dominant_sumsq:.3e},"
|
||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||
f" grad_sumsq={(dominant_grad**2).sum():.3e},"
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||
)
|
||||
|
||||
|
@ -1001,15 +1001,10 @@ def run(rank, world_size, args):
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
parameters_names = []
|
||||
parameters_names.append(
|
||||
[name_param_pair[0] for name_param_pair in model.named_parameters()]
|
||||
)
|
||||
optimizer = ScaledAdam(
|
||||
model.parameters(),
|
||||
model.named_parameters(),
|
||||
lr=params.base_lr,
|
||||
clipping_scale=2.0,
|
||||
parameters_names=parameters_names,
|
||||
)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user