Merge b011448a287784fa14d44d9c6a5aeb8dc9705c27 into 34e40a86b33102576b3442329421178a487e3ea3

Liyong.Guo 2023-09-22 16:43:00 +09:00 committed by GitHub
commit 3c3005bde0
2 changed files with 135 additions and 19 deletions


@@ -18,7 +18,7 @@ import contextlib
import logging
import random
from collections import defaultdict
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from lhotse.utils import fix_random_seed
@@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
Unlike common optimizers, which accept model.parameters() or groups of parameters,
this optimizer can also accept model.named_parameters() or groups of named parameters.
See the docstring of _get_names_of_parameters for the 4 supported cases.
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
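As a quick illustration of the two construction styles this describes, a minimal sketch with a toy model (the import path and model are assumptions; only ScaledAdam's arguments come from this change):

import torch

from optim import ScaledAdam  # assumed import path for this recipe's optim.py

# A toy model; the recipes pass the full transducer model here.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

# Case 1: plain parameters, as before.
opt_plain = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0)

# Case 3: named parameters; names are extracted internally and the
# parameters themselves are handed to the underlying optimizer.
opt_named = ScaledAdam(model.named_parameters(), lr=0.03, clipping_scale=2.0)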
@@ -178,15 +181,8 @@ class ScaledAdam(BatchedOptimizer):
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
"and each str is for a parameter"
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@@ -200,10 +196,135 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period=clipping_update_period,
)
# If params only contains parameters or groups of parameters,
# i.e. when parameter names are not given,
# this flag will be set to False in function _get_names_of_parameters.
self.show_dominant_parameters = True
params, parameters_names = self._get_names_of_parameters(params)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
self.parameters_names = parameters_names
self.show_dominant_parameters = show_dominant_parameters
def _get_names_of_parameters(
self, params_or_named_params
) -> Tuple[List[Dict], List[List[str]]]:
"""
Args:
params_or_named_params: depending on how ScaledAdam is initialized in train.py,
this argument can be one of the following 4 cases,
case 1, a generator of parameters, e.g.:
optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
case 2, a list of parameter groups with different configs, e.g.:
model_param_groups = [
{'params': model.encoder.parameters(), 'lr': 0.05},
{'params': model.decoder.parameters(), 'lr': 0.01},
{'params': model.joiner.parameters(), 'lr': 0.03},
]
optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
case 3, a generator of named parameters, e.g.:
optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
case 4, a list of named-parameter groups with different configs, e.g.:
model_named_param_groups = [
{'params': model.encoder.named_parameters(), 'lr': 0.05},
{'params': model.decoder.named_parameters(), 'lr': 0.01},
{'params': model.joiner.named_parameters(), 'lr': 0.03},
]
optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
For cases 1 and 2, the input params are used directly to initialize the underlying torch.optim.Optimizer.
For cases 3 and 4, names and params are first extracted from the input named_params;
the extracted params are then used to initialize the underlying torch.optim.Optimizer,
and the extracted names are mainly used by the function
`_show_gradient_dominating_parameter`.
Returns:
Returns a tuple containing 2 elements:
- `param_groups` with type List[Dict], each Dict element is a parameter group.
An example of `param_groups` could be:
[
{'params': `one iterable of Parameter`, 'lr': 0.05},
{'params': `another iterable of Parameter`, 'lr': 0.08},
{'params': `a third iterable of Parameter`, 'lr': 0.1},
]
- `param_groups_names` with type List[List[str]],
each `List[str]` is for a group['params'] in param_groups,
and each `str` is the name of a parameter.
A dummy name "foo" is assigned to each parameter
if the input params have no names, i.e. case 1 or case 2.
"""
# variable naming convention in this function:
# p is short for param.
# np is short for named_param.
# p_or_np is short for param_or_named_param.
# cur is short for current.
# group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
# groups is a List[group]
iterable_or_groups = list(params_or_named_params)
if len(iterable_or_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
# The first value of returned tuple.
param_groups = []
# The second value of returned tuple,
# a List[List[str]], each sub-List is for a group.
param_groups_names = []
if not isinstance(iterable_or_groups[0], dict):
# case 1 or case 3,
# the input is an iterable of parameters or named parameters.
param_iterable_cur_group = []
param_names_cur_group = []
for p_or_np in iterable_or_groups:
if isinstance(p_or_np, tuple):
# case 3
name, param = p_or_np
else:
# case 1
assert isinstance(p_or_np, torch.Tensor)
param = p_or_np
# Assign a dummy name as a placeholder
name = "foo"
self.show_dominant_parameters = False
param_iterable_cur_group.append(param)
param_names_cur_group.append(name)
param_groups.append({"params": param_iterable_cur_group})
param_groups_names.append(param_names_cur_group)
else:
# case 2 or case 4,
# the input is groups of parameters or named parameters.
for p_or_np_cur_group in iterable_or_groups:
param_iterable_cur_group = []
param_names_cur_group = []
p_or_np_iterable = p_or_np_cur_group["params"]
for p_or_np in p_or_np_iterable:
if isinstance(p_or_np, tuple):
# case 4
name, param = p_or_np
else:
# case 2
assert isinstance(p_or_np, torch.Tensor)
param = p_or_np
# Assign a dummy name as a placeholder
name = "foo"
self.show_dominant_parameters = False
param_iterable_cur_group.append(param)
param_names_cur_group.append(name)
# The original `params` field contains named_parameters.
# After the following assignment,
# it is changed to an iterable of parameters,
# and other fields, if they exist, keep their original values.
# So param_groups can be used to initialize
# an underlying torch.optim.Optimizer later.
p_or_np_cur_group["params"] = param_iterable_cur_group
param_groups.append(p_or_np_cur_group)
param_groups_names.append(param_names_cur_group)
return param_groups, param_groups_names
def __setstate__(self, state):
super(ScaledAdam, self).__setstate__(state)
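As a sanity check of the normalization _get_names_of_parameters performs for case 4, a small sketch (toy submodules; the import path is an assumption): each group's 'params' field is rewritten to plain parameters, and the extracted names are kept alongside.

import torch

from optim import ScaledAdam  # assumed import path

# Toy stand-ins for model.encoder / model.decoder.
encoder = torch.nn.Linear(4, 4)
decoder = torch.nn.Linear(4, 2)

named_groups = [
    {"params": encoder.named_parameters(), "lr": 0.05},
    {"params": decoder.named_parameters(), "lr": 0.01},
]
opt = ScaledAdam(named_groups, lr=0.03, clipping_scale=2.0)

# One name list per group; a Linear registers 'weight' then 'bias'.
assert len(opt.param_groups) == len(opt.parameters_names)
assert opt.parameters_names[0] == ["weight", "bias"]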
@@ -397,7 +518,7 @@ class ScaledAdam(BatchedOptimizer):
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
):
"""
Show information of parameter wihch dominanting tot_sumsq.
Show information of the parameter which dominates tot_sumsq.
Args:
tuples: a list of tuples of (param, state, param_names)
@@ -415,7 +536,7 @@ class ScaledAdam(BatchedOptimizer):
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
batch_sumsq_orig = batch_grad**2
# Dummpy values used by following `zip` statement.
# Dummy values used by the following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
@@ -448,11 +569,11 @@ class ScaledAdam(BatchedOptimizer):
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f"Parameter dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" grad_sumsq={(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
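To make the quantities in that log line concrete, here is a standalone sketch of the proportion arithmetic (made-up numbers and hypothetical parameter names; dominant_sumsq = grad_sumsq * orig_rms_sq, as in the message):

import torch

# (grad_sumsq, orig_rms) per parameter; values are made up.
stats = {
    "encoder.layers.0.weight": (torch.tensor(4.0), torch.tensor(0.5)),
    "decoder.weight": (torch.tensor(1.0), torch.tensor(0.2)),
}
sumsq = {name: g * rms**2 for name, (g, rms) in stats.items()}
tot_sumsq = sum(sumsq.values())
proportions = {name: (s / tot_sumsq).item() for name, s in sumsq.items()}
dominant = max(proportions, key=proportions.get)
# encoder.layers.0.weight dominates with proportion ~0.96 here.
print(
    f"Parameter dominating tot_sumsq {dominant}"
    f" with proportion {proportions[dominant]:.2f}"
)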


@@ -1001,15 +1001,10 @@ def run(rank, world_size, args):
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in model.named_parameters()]
)
optimizer = ScaledAdam(
model.parameters(),
model.named_parameters(),
lr=params.base_lr,
clipping_scale=2.0,
parameters_names=parameters_names,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
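One behavioral consequence of this change, following from the __init__ logic in optim.py above: passing model.named_parameters() keeps the dominant-parameter diagnostics enabled, whereas the old model.parameters() form would now silently disable them. A minimal sketch (reusing model and params from this function):

# Named form: real names are extracted, so diagnostics stay on.
opt_named = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=2.0)
assert opt_named.show_dominant_parameters

# Plain form: dummy names "foo" are used and diagnostics are turned off.
opt_plain = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
assert not opt_plain.show_dominant_parameters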