Mirror of https://github.com/k2-fsa/icefall.git

Commit a84bd6ae1f: merge 1aa6fc0122981d3a248b08edd2bf4d71f5e27bd0 into abd9437e6d5419a497707748eb935e50976c3b7b
@@ -16,6 +16,7 @@

 import contextlib
 import logging
+import math
 import random
 from collections import defaultdict
 from typing import List, Optional, Tuple, Union
@@ -124,45 +125,48 @@ class BatchedOptimizer(Optimizer):

 class ScaledAdam(BatchedOptimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
     proportional to the norm of that parameter; and also learn the scale of the parameter,
     in log space, subject to upper and lower limits (as if we had factored each parameter as
     param = underlying_param * log_scale.exp())


     Args:
       params: The parameters or param_groups to optimize (like other Optimizer subclasses)
       lr: The learning rate. We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
       clipping_scale: (e.g. 2.0)
           A scale for gradient-clipping: if specified, the normalized gradients
           over the whole model will be clipped to have 2-norm equal to
           `clipping_scale` times the median 2-norm over the most recent period
           of `clipping_update_period` minibatches. By "normalized gradients",
           we mean after multiplying by the rms parameter value for this tensor
           [for non-scalars]; this is appropriate because our update is scaled
           by this quantity.
       betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
           Must satisfy 0 < beta1 <= beta2 < 1.
       scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
           scale of each parameter tensor and scalar parameters of the model.
           If each parameter were decomposed
           as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
           would be the scaling factor on the learning rate of p_scale.
       eps: A general-purpose epsilon to prevent division by zero
       param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll constrain the rms of each non-scalar
           parameter tensor to be >= this value)
       param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll constrain the rms of each non-scalar
           parameter tensor to be <= this value)
       scalar_max: Maximum absolute value for scalar parameters (applicable if your
           model has any parameters with numel() == 1).
       size_update_period: The periodicity, in steps, with which we update the size (scale)
           of the parameter tensor. This is provided to save a little time
           in the update.
       clipping_update_period: if clipping_scale is specified, this is the period
+      percentile_limit: Parameter absolute values above the (1 - percentile_limit) (e.g., 95th) percentile will be limited.
+      p_limit_values: The probability (e.g., 0.1) of modifying the update sign, so as to limit
+          parameter absolute values that are larger than the (1 - percentile_limit) (e.g., 95th) percentile.
     """

     def __init__(
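For orientation, a minimal construction sketch showing where the two new options plug in. It mirrors the run() hunk further below; the toy model, the lr value, and the import path are placeholders of mine, not taken from this patch (in the recipe they come from the real model and params.base_lr).

```python
# Minimal sketch (not from the patch): constructing ScaledAdam with the new options.
import torch
from optim import ScaledAdam  # icefall's optim.py, as modified by this patch

model = torch.nn.Linear(80, 512)  # stand-in for the real model
parameters_names = [[name for name, _ in model.named_parameters()]]

optimizer = ScaledAdam(
    model.parameters(),
    lr=0.045,                # placeholder; the recipe uses params.base_lr
    clipping_scale=2.0,
    parameters_names=parameters_names,
    percentile_limit=0.05,   # limit |p| values above their 95th percentile
    p_limit_values=0.1,      # apply the sign modification with probability 0.1
)
```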
@@ -180,6 +184,8 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period=100,
         parameters_names=None,
         show_dominant_parameters=True,
+        percentile_limit=0.05,
+        p_limit_values=0.0,
     ):
         assert parameters_names is not None, (
             "Please prepare parameters_names,"
@@ -197,6 +203,8 @@ class ScaledAdam(BatchedOptimizer):
             scalar_max=scalar_max,
             size_update_period=size_update_period,
             clipping_update_period=clipping_update_period,
+            percentile_limit=percentile_limit,
+            p_limit_values=p_limit_values,
         )

         super(ScaledAdam, self).__init__(params, defaults)
@@ -292,6 +300,9 @@ class ScaledAdam(BatchedOptimizer):
                 size_update_period, *param_rms.shape, **kwargs
             )

+            if group["p_limit_values"] > 0:
+                state["stored_percentiles"] = torch.ones_like(param_rms)
+
         # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

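A small shape sketch of my own (not part of the patch) to clarify what stored_percentiles looks like: ScaledAdam batches same-shaped tensors along a leading dimension, param_rms keeps one value per stacked tensor, and stored_percentiles mirrors that shape so each tensor tracks its own percentile of |p|, starting from 1.0.

```python
import torch

batch_size = 3                        # three parameter tensors of identical shape, stacked
p = torch.randn(batch_size, 256, 80)  # batched parameter, as seen inside the optimizer
param_rms = (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
print(param_rms.shape)                # torch.Size([3, 1, 1])

stored_percentiles = torch.ones_like(param_rms)  # one tracked percentile per tensor
print(stored_percentiles.shape)       # torch.Size([3, 1, 1])
```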
@@ -598,7 +609,12 @@ class ScaledAdam(BatchedOptimizer):
         alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

         delta = state["delta"]
-        delta.add_(grad * alpha)
+
+        if random.random() >= group["p_limit_values"]:
+            delta.add_(grad * alpha)
+        else:
+            delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state))
+
         p.add_(delta)

     def _step_scalar(self, group: dict, p: Tensor, state: dict):
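Note how this gate behaves: with probability 1 - p_limit_values the update is unchanged, otherwise it is multiplied elementwise by the ±1 tensor from _limit_values_sign. With the default p_limit_values=0.0, random.random() (drawn from [0.0, 1.0)) always satisfies the condition, so the optimizer reduces to the previous delta.add_(grad * alpha). A toy sketch of the gate in isolation (limit_sign is stubbed here; in the patch it comes from _limit_values_sign):

```python
import random
import torch

p_limit_values = 0.1
delta = torch.zeros(4)
grad = torch.randn(4)
alpha = -0.01

if random.random() >= p_limit_values:                    # plain update, probability 0.9
    delta.add_(grad * alpha)
else:                                                    # sign-limited update, probability 0.1
    limit_sign = 1 - 2 * (torch.rand(4) < 0.5).float()   # stub: random ±1 per element
    delta.add_((grad * alpha) * limit_sign)
```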
@@ -625,6 +641,67 @@ class ScaledAdam(BatchedOptimizer):
         p.clamp_(min=-scalar_max, max=scalar_max)
         p.add_(delta)

+    def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
+        """Decide whether to modify the sign of the current update.
+
+        For each parameter tensor in p, we store a certain percentile (e.g., the 95th)
+        of the absolute values, say stored_percentiles.
+        If more than percentile_limit (e.g., 5%) of the parameter's absolute values are larger than stored_percentiles:
+          Multiply stored_percentiles by (1.0 + lr / p_limit_values).
+        Else:
+          Multiply stored_percentiles by (1.0 - lr / p_limit_values).
+        For all parameters whose absolute values are larger than stored_percentiles,
+        modify the sign of the update so that it points 'inward'.
+
+        Args:
+          group: A dict which will be used to look up configuration values
+          p: The parameter to be updated
+          grad: The grad of p
+          state: The state-dict corresponding to parameter p
+
+        Returns: A tensor with the same shape as p, filled with 1 or -1.
+        """
+        lr = group["lr"]
+        # The probability to limit the absolute values
+        p_limit_values = group["p_limit_values"]  # e.g., 0.1
+        # Parameter absolute values over the (1 - percentile_limit) percentile will be limited.
+        percentile_limit = group["percentile_limit"]  # e.g., 0.05
+        # It stores the (1 - percentile_limit) percentiles of the parameter absolute values,
+        # with a shape like (batch_size, 1, 1, 1, 1)
+        stored_percentiles = state["stored_percentiles"]
+
+        batch_size = p.shape[0]
+        numel = p.numel() // batch_size
+        p_abs = p.abs()
+
+        p_exceed = p_abs > stored_percentiles  # same shape as p
+        # The proportion that exceeds stored_percentiles
+        proportion_exceed = (
+            p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
+        )
+        # Update stored_percentiles
+        update_sign = (proportion_exceed - percentile_limit).sign()
+        stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
+            min=1.0e-20
+        )
+
+        # For those parameters that exceed stored_percentiles,
+        # flip the update sign if they would get larger absolute values
+        limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10)
+
+        # TODO: will remove the logging below
+        if random.random() < 0.1:
+            mask = limit_sign == -1
+            if mask.any():
+                logging.info(f"p.shape: {p.shape}")
+                logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
+                logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
+                logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
+                proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
+                logging.info(f"proportion_sign_change: {proportion_sign_change}")
+
+        return limit_sign
+
+
 class LRScheduler(object):
     """
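To make the percentile tracking concrete, here is a standalone sketch of the same logic on random data. It is my own illustration, not code from the patch, and it borrows the patch's defaults (percentile_limit=0.05, p_limit_values=0.1) plus a placeholder lr of 0.045. Because the multiplicative steps are coarse at this lr, stored_percentiles does not converge exactly but hovers in a band around the true 95th percentile of |p|.

```python
import torch

lr, p_limit_values, percentile_limit = 0.045, 0.1, 0.05

p = torch.randn(2, 1000) * 5.0         # batch of 2 stacked parameter tensors
grad = torch.randn_like(p)
stored_percentiles = torch.ones(2, 1)  # initialized to 1.0, as in _init_state

for _ in range(100):                   # repeated steps let the estimate home in
    p_exceed = p.abs() > stored_percentiles
    proportion_exceed = p_exceed.float().mean(dim=1, keepdim=True)
    update_sign = (proportion_exceed - percentile_limit).sign()
    stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)

print(stored_percentiles.view(-1))           # hovers around the value below
print(torch.quantile(p.abs(), 0.95, dim=1))  # true 95th percentile of |p| per tensor

# The sign tensor flips the update only for entries that exceed the tracked percentile,
# whose update would push them further from zero, and whose |p| > 10 (as in the patch).
limit_sign = 1 - 2 * (
    (p.abs() > stored_percentiles) * ((p.sign() * grad.sign()) < 0) * (p.abs() > 10)
)
```

The remaining two hunks, against get_parser() and run(), are in the recipe's training script (its path is not shown in this view): they expose the new options on the command line and forward them to ScaledAdam.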
@@ -376,6 +376,22 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--p-limit-values",
+        type=float,
+        default=0.0,
+        help="""The probability (e.g., 0.1) of modifying the update sign, so as to limit the
+        parameter absolute values that are larger than the (1 - percentile_limit) (e.g., 95th) percentile.""",
+    )
+
+    parser.add_argument(
+        "--percentile-limit",
+        type=float,
+        default=0.05,
+        help="""Parameter absolute values over the (1 - percentile_limit) (e.g., 95th) percentile
+        will be limited.""",
+    )
+
     add_model_arguments(parser)

     return parser
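On the command line the new flags are simply appended to the usual training command, e.g. --p-limit-values 0.1 --percentile-limit 0.05. As a sanity sketch (illustration only, using a throwaway parser that mirrors the two add_argument calls above), this is how the values surface on the parsed namespace that run() later reads as params.p_limit_values and params.percentile_limit:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--p-limit-values", type=float, default=0.0)
parser.add_argument("--percentile-limit", type=float, default=0.05)

args = parser.parse_args(["--p-limit-values", "0.1"])
print(args.p_limit_values)    # 0.1  -> passed as p_limit_values=params.p_limit_values
print(args.percentile_limit)  # 0.05 -> passed as percentile_limit=params.percentile_limit
```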
@@ -1009,6 +1025,8 @@ def run(rank, world_size, args):
         lr=params.base_lr,
         clipping_scale=2.0,
         parameters_names=parameters_names,
+        p_limit_values=params.p_limit_values,
+        percentile_limit=params.percentile_limit,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)