Merge branch 'quantization' of github.com:yaozengwei/icefall into quantization

yaozengwei 2023-03-23 20:06:55 +08:00
commit 1aa6fc0122


@@ -125,48 +125,48 @@ class BatchedOptimizer(Optimizer):
class ScaledAdam(BatchedOptimizer):
    """
    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
    proportional to the norm of that parameter; and also learn the scale of the parameter,
    in log space, subject to upper and lower limits (as if we had factored each parameter as
    param = underlying_param * log_scale.exp())
    Args:
        params: The parameters or param_groups to optimize (like other Optimizer subclasses)
        lr: The learning rate. We will typically use a learning rate schedule that starts
            at 0.03 and decreases over time, i.e. much higher than other common
            optimizers.
        clipping_scale: (e.g. 2.0)
            A scale for gradient-clipping: if specified, the normalized gradients
            over the whole model will be clipped to have 2-norm equal to
            `clipping_scale` times the median 2-norm over the most recent period
            of `clipping_update_period` minibatches. By "normalized gradients",
            we mean after multiplying by the rms parameter value for this tensor
            [for non-scalars]; this is appropriate because our update is scaled
            by this quantity.
        betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
            Must satisfy 0 < beta1 <= beta2 < 1.
        scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
            scale of each parameter tensor and scalar parameters of the model.
            If each parameter were decomposed
            as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
            would be the scaling factor on the learning rate of p_scale.
        eps: A general-purpose epsilon to prevent division by zero
        param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
            learning the scale on the parameters (we'll constrain the rms of each non-scalar
            parameter tensor to be >= this value)
        param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
            learning the scale on the parameters (we'll constrain the rms of each non-scalar
            parameter tensor to be <= this value)
        scalar_max: Maximum absolute value for scalar parameters (applicable if your
            model has any parameters with numel() == 1).
        size_update_period: The periodicity, in steps, with which we update the size (scale)
            of the parameter tensor. This is provided to save a little time
            in the update.
        clipping_update_period: if clipping_scale is specified, this is the period, in
            minibatches, over which we track gradient norms; the median norm over the
            most recent such period sets the clipping threshold.
        percentile_limit: Parameter absolute values above the (1 - percentile_limit)
            percentile (e.g., the 95th percentile, for percentile_limit=0.05) will be limited.
        p_limit_values: The probability (e.g., 0.1) with which we flip the update sign, so as
            to limit parameters whose absolute values exceed the (1 - percentile_limit)
            (e.g., 95th) percentile.
""" """
    def __init__(
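As a side note on the factorization the docstring opens with, here is a small sketch of what learning the scale in log space means; the names are illustrative, not from optim.py, and ScaledAdam applies this implicitly rather than storing two tensors:

```python
import torch

# Toy illustration of param = underlying_param * log_scale.exp().
underlying_param = torch.randn(4, 4)
underlying_param /= (underlying_param ** 2).mean().sqrt()  # rms == 1.0
log_scale = torch.tensor(0.5).log()                        # scale learned in log space
param = underlying_param * log_scale.exp()
# The rms of param equals exp(log_scale); bounding log_scale is how
# param_min_rms / param_max_rms constrain the parameter rms.
print((param ** 2).mean().sqrt())  # ~0.5
```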
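A minimal usage sketch, assuming the constructor accepts the documented arguments under these keyword names; the import path, model, and training loop are hypothetical stand-ins:

```python
import torch
from optim import ScaledAdam  # assumed import path within icefall

model = torch.nn.Linear(80, 512)  # stand-in for a real network

optimizer = ScaledAdam(
    model.parameters(),
    lr=0.03,             # the docstring's typical starting lr, decayed by a schedule
    clipping_scale=2.0,  # clip grads to 2x the recent median normalized 2-norm
)

for _ in range(10):
    x = torch.randn(4, 80)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```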
@@ -681,10 +681,14 @@ class ScaledAdam(BatchedOptimizer):
        p_exceed = p_abs > stored_percentiles  # same shape as p
        # The proportion that exceeds stored_percentiles
        proportion_exceed = (
            p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
        )
        # Update stored_percentiles
        update_sign = (proportion_exceed - percentile_limit).sign()
        stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
            min=1.0e-20
        )
        # For those parameters that exceed stored_percentiles,
        # flip the update sign if they would otherwise get larger absolute values
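To make the hunk above concrete, here is a self-contained toy run of the same stored-percentile update. The shapes, the meaning of `numel`, and the hyperparameter values are assumptions for illustration, not the surrounding icefall code:

```python
import torch

p = torch.randn(3, 16, 16)                       # batch of 3 stacked parameter tensors
numel = p[0].numel()                             # entries per parameter (assumed meaning)
stored_percentiles = torch.full((3, 1, 1), 1.0)  # running percentile estimate per tensor
percentile_limit = 0.05                          # allow ~5% of entries above the estimate
p_limit_values = 0.1
lr = 0.03

p_abs = p.abs()
p_exceed = p_abs > stored_percentiles  # same shape as p
proportion_exceed = (
    p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
)
# Nudge each stored percentile up when too many entries exceed it and down
# when too few do, so it tracks the (1 - percentile_limit) percentile over time.
update_sign = (proportion_exceed - percentile_limit).sign()
stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
    min=1.0e-20
)
```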