Fix for style_check

Yifan Yang 2023-03-23 17:27:57 +08:00 committed by GitHub
parent 8c91b9d0cd
commit df4b615993


@@ -125,48 +125,48 @@ class BatchedOptimizer(Optimizer):
class ScaledAdam(BatchedOptimizer):
"""
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
proportional to the norm of that parameter; and also learn the scale of the parameter,
in log space, subject to upper and lower limits (as if we had factored each parameter as
param = underlying_param * log_scale.exp())
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
clipping_scale: (e.g. 2.0)
A scale for gradient-clipping: if specified, the normalized gradients
over the whole model will be clipped to have 2-norm equal to
`clipping_scale` times the median 2-norm over the most recent period
of `clipping_update_period` minibatches. By "normalized gradients",
we mean after multiplying by the rms parameter value for this tensor
[for non-scalars]; this is appropriate because our update is scaled
by this quantity.
betas: (beta1, beta2) are momentum constants for regular momentum and the moving sum-sq grad.
Must satisfy 0 < beta1 <= beta2 < 1.
scalar_lr_scale: A scaling factor on the learning rate that we use to update the
scale of each parameter tensor and the scalar parameters of the model.
If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
would be the scaling factor on the learning rate of p_scale.
eps: A general-purpose epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be >= this value)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be <= this value)
scalar_max: Maximum absolute value for scalar parameters (applicable if your
model has any parameters with numel() == 1).
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time
in the update.
clipping_update_period: if clipping_scale is specified, this is the period, in minibatches, with which we update the median gradient-norm statistics used for clipping.
percentile_limit: Parameter absolute values above the (1 - percentile_limit) percentile (e.g., the 95th percentile) will be limited.
p_limit_values: The probability (e.g., 0.1) of modifying the update sign, so as to limit
parameter absolute values that exceed the (1 - percentile_limit) percentile (e.g., the 95th percentile).
"""
def __init__(
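
To make the argument list above concrete, here is a rough usage sketch (not part of this commit). The import path, the stand-in model, and the keyword values are illustrative assumptions drawn from the docstring, not defaults taken from this file.

import torch

from optim import ScaledAdam  # assumed import path for this file

model = torch.nn.Linear(80, 512)  # stand-in model for illustration

optimizer = ScaledAdam(
    model.parameters(),
    lr=0.03,                     # high starting LR; decay it with an external schedule
    clipping_scale=2.0,          # clip normalized grads at 2x the recent median norm
    betas=(0.9, 0.98),           # momentum and moving sum-sq-grad constants
    scalar_lr_scale=0.1,         # LR factor for the learned scales and scalar params
    eps=1.0e-08,                 # epsilon to prevent division by zero
    param_min_rms=1.0e-05,       # lower bound on each non-scalar tensor's RMS
    param_max_rms=3.0,           # upper bound on each non-scalar tensor's RMS
    scalar_max=10.0,             # |value| cap for numel() == 1 parameters
    size_update_period=4,        # update the learned scale every 4 steps
    clipping_update_period=100,  # refresh clipping statistics every 100 minibatches
    percentile_limit=0.05,       # limit values above the 95th percentile
    p_limit_values=0.1,          # probability of flipping the update sign
)

loss = model(torch.randn(16, 80)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
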
@@ -681,10 +681,14 @@ class ScaledAdam(BatchedOptimizer):
p_exceed = p_abs > stored_percentiles # same shape as p
# The proportion that exceeds stored_percentiles
proportion_exceed = (
    p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
)
# Update stored_percentiles
update_sign = (proportion_exceed - percentile_limit).sign()
stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
    min=1.0e-20
)
# For those parameters that exceed stored_percentiles,
# flip the update sign if the update would make their absolute values larger.
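
For context on the reformatted lines above, the following self-contained sketch reproduces the percentile-tracking update from this hunk. The tensor shape, the values of lr, percentile_limit, and p_limit_values, and the definition of numel are illustrative assumptions; only the update rule itself mirrors the diff.

import torch

torch.manual_seed(0)

lr = 0.03                # assumed learning rate
percentile_limit = 0.05  # aim to keep ~5% of values above the tracked threshold
p_limit_values = 0.1     # scale applied when moving the threshold

# A batched parameter tensor as used by BatchedOptimizer: dim 0 indexes the batch.
p = torch.randn(4, 256, 192)
numel = p[0].numel()  # assumed: element count of each individual parameter

# One tracked threshold per batched parameter, broadcastable against p.
stored_percentiles = torch.full((4, 1, 1), 1.0)

for _ in range(10):
    p_abs = p.abs()
    p_exceed = p_abs > stored_percentiles  # same shape as p
    # Fraction of elements above the current threshold, per batched parameter.
    proportion_exceed = (
        p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
    )
    # Raise the threshold when too many elements exceed it, lower it otherwise.
    update_sign = (proportion_exceed - percentile_limit).sign()
    stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
        min=1.0e-20
    )

print(stored_percentiles.squeeze())  # drifts toward the 95th percentile of |p|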