mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge 1aa6fc0122981d3a248b08edd2bf4d71f5e27bd0 into abd9437e6d5419a497707748eb935e50976c3b7b
This commit is contained in:
commit
a84bd6ae1f
@ -16,6 +16,7 @@
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@ -124,45 +125,48 @@ class BatchedOptimizer(Optimizer):
|
||||
|
||||
class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
||||
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
||||
param = underlying_param * log_scale.exp())
|
||||
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
||||
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
||||
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
||||
param = underlying_param * log_scale.exp())
|
||||
|
||||
|
||||
Args:
|
||||
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
||||
lr: The learning rate. We will typically use a learning rate schedule that starts
|
||||
at 0.03 and decreases over time, i.e. much higher than other common
|
||||
optimizers.
|
||||
clipping_scale: (e.g. 2.0)
|
||||
A scale for gradient-clipping: if specified, the normalized gradients
|
||||
over the whole model will be clipped to have 2-norm equal to
|
||||
`clipping_scale` times the median 2-norm over the most recent period
|
||||
of `clipping_update_period` minibatches. By "normalized gradients",
|
||||
we mean after multiplying by the rms parameter value for this tensor
|
||||
[for non-scalars]; this is appropriate because our update is scaled
|
||||
by this quantity.
|
||||
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
||||
Must satisfy 0 < beta <= beta2 < 1.
|
||||
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
||||
scale of each parameter tensor and scalar parameters of the mode..
|
||||
If each parameter were decomposed
|
||||
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
||||
would be a the scaling factor on the learning rate of p_scale.
|
||||
eps: A general-purpose epsilon to prevent division by zero
|
||||
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
||||
parameter tensor to be >= this value)
|
||||
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
||||
parameter tensor to be <= this value)
|
||||
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
||||
model has any parameters with numel() == 1).
|
||||
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
||||
of the parameter tensor. This is provided to save a little time
|
||||
in the update.
|
||||
clipping_update_period: if clipping_scale is specified, this is the period
|
||||
Args:
|
||||
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
||||
lr: The learning rate. We will typically use a learning rate schedule that starts
|
||||
at 0.03 and decreases over time, i.e. much higher than other common
|
||||
optimizers.
|
||||
clipping_scale: (e.g. 2.0)
|
||||
A scale for gradient-clipping: if specified, the normalized gradients
|
||||
over the whole model will be clipped to have 2-norm equal to
|
||||
`clipping_scale` times the median 2-norm over the most recent period
|
||||
of `clipping_update_period` minibatches. By "normalized gradients",
|
||||
we mean after multiplying by the rms parameter value for this tensor
|
||||
[for non-scalars]; this is appropriate because our update is scaled
|
||||
by this quantity.
|
||||
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
||||
Must satisfy 0 < beta <= beta2 < 1.
|
||||
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
||||
scale of each parameter tensor and scalar parameters of the mode..
|
||||
If each parameter were decomposed
|
||||
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
||||
would be a the scaling factor on the learning rate of p_scale.
|
||||
eps: A general-purpose epsilon to prevent division by zero
|
||||
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
||||
parameter tensor to be >= this value)
|
||||
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
||||
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
||||
parameter tensor to be <= this value)
|
||||
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
||||
model has any parameters with numel() == 1).
|
||||
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
||||
of the parameter tensor. This is provided to save a little time
|
||||
in the update.
|
||||
clipping_update_period: if clipping_scale is specified, this is the period
|
||||
percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited.
|
||||
p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the
|
||||
parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -180,6 +184,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True,
|
||||
percentile_limit=0.05,
|
||||
p_limit_values=0.0,
|
||||
):
|
||||
assert parameters_names is not None, (
|
||||
"Please prepare parameters_names,"
|
||||
@ -197,6 +203,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
scalar_max=scalar_max,
|
||||
size_update_period=size_update_period,
|
||||
clipping_update_period=clipping_update_period,
|
||||
percentile_limit=percentile_limit,
|
||||
p_limit_values=p_limit_values,
|
||||
)
|
||||
|
||||
super(ScaledAdam, self).__init__(params, defaults)
|
||||
@ -292,6 +300,9 @@ class ScaledAdam(BatchedOptimizer):
|
||||
size_update_period, *param_rms.shape, **kwargs
|
||||
)
|
||||
|
||||
if group["p_limit_values"] > 0:
|
||||
state["stored_percentiles"] = torch.ones_like(param_rms)
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
@ -598,7 +609,12 @@ class ScaledAdam(BatchedOptimizer):
|
||||
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
|
||||
|
||||
delta = state["delta"]
|
||||
delta.add_(grad * alpha)
|
||||
|
||||
if random.random() >= group["p_limit_values"]:
|
||||
delta.add_(grad * alpha)
|
||||
else:
|
||||
delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state))
|
||||
|
||||
p.add_(delta)
|
||||
|
||||
def _step_scalar(self, group: dict, p: Tensor, state: dict):
|
||||
@ -625,6 +641,67 @@ class ScaledAdam(BatchedOptimizer):
|
||||
p.clamp_(min=-scalar_max, max=scalar_max)
|
||||
p.add_(delta)
|
||||
|
||||
def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
|
||||
"""Decide whether to modify the sign of the cerrent update.
|
||||
|
||||
For each parameter tensor in p, we store the centain percentile (e.g., 95%),
|
||||
say stored_percentiles.
|
||||
If more than 5% parameter absolute values are larger than stored_percentiles:
|
||||
Increase stored_percentiles by (1.0 + lr / p_limit_values).
|
||||
Else:
|
||||
Decrease stored_percentiles by (1.0 - lr / p_limit_values).
|
||||
For all parameters whose absolute values are larger than stored_percentiles,
|
||||
modify the sign of the update so that it points 'inward'.
|
||||
|
||||
Args:
|
||||
group: A dict which will be used to look up configuration values
|
||||
p: The parameter to be updated
|
||||
grad: The grad of p
|
||||
state: The state-dict corresponding to parameter p
|
||||
|
||||
Returns: A tensor with same shape as p, filled with 1 or -1.
|
||||
"""
|
||||
lr = group["lr"]
|
||||
# The probability to limit the absolute values
|
||||
p_limit_values = group["p_limit_values"] # e.g., 0.1
|
||||
# The parameter absolute values over 1-percentile_limit percentile will be limited.
|
||||
percentile_limit = group["percentile_limit"] # e.g., 0.05
|
||||
# It stores the percentiles (i.e, 1-percentile_limit) of the parameter absolute values,
|
||||
# with a shape like (batch_size, 1, 1, 1, 1)
|
||||
stored_percentiles = state["stored_percentiles"]
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
p_abs = p.abs()
|
||||
|
||||
p_exceed = p_abs > stored_percentiles # same shape as p
|
||||
# The proportion that exceeds stored_percentiles
|
||||
proportion_exceed = (
|
||||
p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
|
||||
)
|
||||
# Update store_percentiles
|
||||
update_sign = (proportion_exceed - percentile_limit).sign()
|
||||
stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
|
||||
min=1.0e-20
|
||||
)
|
||||
|
||||
# For these parameters that exceed stored_percentile,
|
||||
# flip the update sign if they would get larger absolute values
|
||||
limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10)
|
||||
|
||||
# TODO: will remove the log bellow
|
||||
if random.random() < 0.1:
|
||||
mask = limit_sign == -1
|
||||
if mask.any():
|
||||
logging.info(f"p.shape: {p.shape}")
|
||||
logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
|
||||
logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
|
||||
logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
|
||||
proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
|
||||
logging.info(f"proportion_sign_change: {proportion_sign_change}")
|
||||
|
||||
return limit_sign
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
"""
|
||||
|
@ -376,6 +376,22 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--p-limit-values",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""The probability (e.g., 0.1) to modify the update sign, so as to limit the
|
||||
parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--percentile-limit",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="""The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile
|
||||
will be limited.""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -1009,6 +1025,8 @@ def run(rank, world_size, args):
|
||||
lr=params.base_lr,
|
||||
clipping_scale=2.0,
|
||||
parameters_names=parameters_names,
|
||||
p_limit_values=params.p_limit_values,
|
||||
percentile_limit=params.percentile_limit,
|
||||
)
|
||||
|
||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user