Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-09 17:14:20 +00:00)

Commit 9e4ce825ac: simplify interface
Parent: 65d7192dca
@@ -18,7 +18,7 @@ import contextlib
 import logging
 import random
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from lhotse.utils import fix_random_seed
@@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):

     Args:
          params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
+                  Unlike common optimizers, which accept model.parameters() or groups of parameters(),
+                  this optimizer can also accept model.named_parameters() or groups of named_parameters().
+                  See the comments of function _get_names_of_parameters for its 4 possible cases.
          lr:  The learning rate.  We will typically use a learning rate schedule that starts
               at 0.03 and decreases over time, i.e. much higher than other common
               optimizers.
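The distinction the new docstring leans on is simply what the two iterators yield: model.parameters() produces bare tensors, while model.named_parameters() produces (name, parameter) tuples. A minimal sketch of that difference, using a toy nn.Linear that is not part of this commit:

import torch
import torch.nn as nn

toy = nn.Linear(4, 2)  # illustration only; any nn.Module behaves the same

# parameters() yields bare tensors (docstring cases 1 and 2) ...
for p in toy.parameters():
    assert isinstance(p, torch.Tensor)

# ... while named_parameters() yields (name, tensor) tuples (cases 3 and 4),
# which is how the optimizer can tell the two call styles apart.
for item in toy.named_parameters():
    assert isinstance(item, tuple)
    name, p = item
    print(name)  # "weight", "bias"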
@@ -178,15 +181,8 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
-        parameters_names=None,
-        show_dominant_parameters=True,
     ):

-        assert parameters_names is not None, (
-            "Please prepare parameters_names,"
-            "which is a List[List[str]]. Each List[str] is for a group"
-            "and each str is for a parameter"
-        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -200,10 +196,135 @@ class ScaledAdam(BatchedOptimizer):
             clipping_update_period=clipping_update_period,
         )

+        # If params only contains parameters or groups of parameters,
+        # i.e. when parameter names are not given,
+        # this flag will be set to False in function _get_names_of_parameters.
+        self.show_dominant_parameters = True
+        params, parameters_names = self._get_names_of_parameters(params)
         super(ScaledAdam, self).__init__(params, defaults)
         assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
-        self.show_dominant_parameters = show_dominant_parameters
+
+    def _get_names_of_parameters(
+        self, params_or_named_params
+    ) -> Tuple[List[Dict], List[List[str]]]:
+        """
+        Args:
+          params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+            this argument could be one of the following 4 cases:
+
+            case 1, a generator of parameters, e.g.:
+              optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 2, a list of parameter groups with different configs, e.g.:
+              model_param_groups = [
+                  {'params': model.encoder.parameters(), 'lr': 0.05},
+                  {'params': model.decoder.parameters(), 'lr': 0.01},
+                  {'params': model.joiner.parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+            case 3, a generator of named_parameters, e.g.:
+              optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 4, a list of named_parameter groups with different configs, e.g.:
+              model_named_param_groups = [
+                  {'params': model.encoder.named_parameters(), 'lr': 0.05},
+                  {'params': model.decoder.named_parameters(), 'lr': 0.01},
+                  {'params': model.joiner.named_parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+          For case 1 and case 2, the input params are used directly to initialize the underlying torch optimizer.
+          For case 3 and case 4, names and params are first extracted from the input named_params;
+          the extracted params are then used to initialize the underlying torch optimizer,
+          and the extracted names are mainly used by the function
+          `_show_gradient_dominating_parameter`.
+
+        Returns:
+          Returns a tuple containing 2 elements:
+          - `param_groups`, with type List[Dict]; each Dict element is a parameter group.
+            An example of `param_groups` could be:
+            [
+                {'params': `one iterable of Parameter`, 'lr': 0.05},
+                {'params': `another iterable of Parameter`, 'lr': 0.08},
+                {'params': `a third iterable of Parameter`, 'lr': 0.1},
+            ]
+          - `param_groups_names`, with type List[List[str]];
+            each `List[str]` is for a group['params'] in param_groups,
+            and each `str` is the name of a parameter.
+            A dummy name "foo" is assigned to each parameter
+            if the input is params without names, i.e. case 1 or case 2.
+        """
+        # Variable naming convention in this function:
+        #   p is short for param.
+        #   np is short for named_param.
+        #   p_or_np is short for param_or_named_param.
+        #   curt is short for current.
+        #   group is a dict, e.g. {'params': iterable of parameters, 'lr': 0.05, other fields}.
+        #   groups is a List[group].
+
+        iterable_or_groups = list(params_or_named_params)
+        if len(iterable_or_groups) == 0:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # The first value of the returned tuple.
+        param_groups = []
+
+        # The second value of the returned tuple,
+        # a List[List[str]]; each sub-List is for a group.
+        param_groups_names = []
+
+        if not isinstance(iterable_or_groups[0], dict):
+            # case 1 or case 3:
+            # the input is an iterable of parameters or named parameters.
+            param_iterable_curt_group = []
+            param_names_curt_group = []
+            for p_or_np in iterable_or_groups:
+                if isinstance(p_or_np, tuple):
+                    # case 3
+                    name, param = p_or_np
+                else:
+                    # case 1
+                    assert isinstance(p_or_np, torch.Tensor)
+                    param = p_or_np
+                    # Assign a dummy name as a placeholder.
+                    name = "foo"
+                    self.show_dominant_parameters = False
+
+                param_iterable_curt_group.append(param)
+                param_names_curt_group.append(name)
+            param_groups.append({"params": param_iterable_curt_group})
+            param_groups_names.append(param_names_curt_group)
+        else:
+            # case 2 or case 4:
+            # the input is groups of parameters or named parameters.
+            for p_or_np_curt_group in iterable_or_groups:
+                param_iterable_curt_group = []
+                param_names_curt_group = []
+                p_or_np_iterable = p_or_np_curt_group["params"]
+                for p_or_np in p_or_np_iterable:
+                    if isinstance(p_or_np, tuple):
+                        # case 4
+                        name, param = p_or_np
+                    else:
+                        # case 2
+                        assert isinstance(p_or_np, torch.Tensor)
+                        param = p_or_np
+                        # Assign a dummy name as a placeholder.
+                        name = "foo"
+                        self.show_dominant_parameters = False
+
+                    param_iterable_curt_group.append(param)
+                    param_names_curt_group.append(name)
+
+                # The original `params` field contains named_parameters.
+                # After the following assignment it is changed to an
+                # iterable of parameters, while the other fields, if they
+                # exist, keep their original values. So param_groups can
+                # be used to initialize an underlying torch.Optimizer later.
+                p_or_np_curt_group["params"] = param_iterable_curt_group
+
+                param_groups.append(p_or_np_curt_group)
+                param_groups_names.append(param_names_curt_group)
+
+        return param_groups, param_groups_names

     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
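To make the return contract above concrete, here is a rough standalone sketch of just the simple, non-grouped branch (cases 1 and 3). It is not the icefall code itself; split_named_params and the toy nn.Linear are names introduced only for this example.

from typing import Dict, List, Tuple

import torch
import torch.nn as nn


def split_named_params(params_or_named_params) -> Tuple[List[Dict], List[List[str]]]:
    """Separate an iterable of parameters or (name, parameter) tuples into a
    single param_group plus a parallel list of parameter names."""
    params, names = [], []
    for p_or_np in params_or_named_params:
        if isinstance(p_or_np, tuple):        # case 3: (name, Parameter)
            name, param = p_or_np
        else:                                 # case 1: bare Parameter
            assert isinstance(p_or_np, torch.Tensor)
            name, param = "foo", p_or_np      # dummy placeholder name
        params.append(param)
        names.append(name)
    return [{"params": params}], [names]


toy = nn.Linear(4, 2)                         # toy model, illustration only
groups, group_names = split_named_params(toy.named_parameters())
print(group_names)                            # [['weight', 'bias']]
print(len(groups[0]["params"]))               # 2

The grouped branch (cases 2 and 4) does the same per group and additionally rewrites each group's 'params' field in place before handing the groups to the underlying torch optimizer.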
@@ -398,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
         self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
     ):
         """
-        Show information of parameter wihch dominanting tot_sumsq.
+        Show information of the parameter which is dominating tot_sumsq.

         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -416,7 +537,7 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 batch_sumsq_orig = batch_grad**2
-                # Dummpy values used by following `zip` statement.
+                # Dummy values used by following `zip` statement.
                 batch_rms_orig = torch.ones(p.shape[0])
             else:
                 batch_rms_orig = state["param_rms"]
@@ -449,11 +570,11 @@ class ScaledAdam(BatchedOptimizer):
                 dominant_grad,
             ) = sorted_by_proportion[dominant_param_name]
             logging.info(
-                f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+                f"Parameter dominating tot_sumsq {dominant_param_name}"
                 f" with proportion {dominant_proportion:.2f},"
                 f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
                 f"={dominant_sumsq:.3e},"
-                f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+                f" grad_sumsq={(dominant_grad**2).sum():.3e},"
                 f" orig_rms_sq={(dominant_rms**2).item():.3e}"
             )

@@ -988,15 +988,10 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    parameters_names = []
-    parameters_names.append(
-        [name_param_pair[0] for name_param_pair in model.named_parameters()]
-    )
     optimizer = ScaledAdam(
-        model.parameters(),
+        model.named_parameters(),
         lr=params.base_lr,
         clipping_scale=2.0,
-        parameters_names=parameters_names,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
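The training-script hunk above is the point of the whole change: the hand-built parameters_names list disappears and model.named_parameters() is passed directly. A hedged usage sketch of the simplified call, where the "from optim import ScaledAdam" import path and the concrete lr value are assumptions for illustration and nn.Linear stands in for the real model:

import torch.nn as nn

from optim import ScaledAdam  # assumed recipe-local module providing ScaledAdam

model = nn.Linear(4, 2)  # stand-in for the real model, illustration only

# Names are now extracted inside the optimizer, so no parameters_names argument
# (and no show_dominant_parameters flag) needs to be passed.
optimizer = ScaledAdam(
    model.named_parameters(),
    lr=0.045,            # placeholder for params.base_lr
    clipping_scale=2.0,
)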