From df4b615993bea4bedca377765f58c14560c161e3 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Thu, 23 Mar 2023 17:27:57 +0800
Subject: [PATCH] Fix for style_check

---
 .../ASR/pruned_transducer_stateless7/optim.py | 88 ++++++++++---------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 4c9010124..330e23abf 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -125,48 +125,48 @@ class BatchedOptimizer(Optimizer):

 class ScaledAdam(BatchedOptimizer):
     """
-     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
-     proportional to the norm of that parameter; and also learn the scale of the parameter,
-     in log space, subject to upper and lower limits (as if we had factored each parameter as
-     param = underlying_param * log_scale.exp())
+    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+    proportional to the norm of that parameter; and also learn the scale of the parameter,
+    in log space, subject to upper and lower limits (as if we had factored each parameter as
+    param = underlying_param * log_scale.exp())


-     Args:
-       params: The parameters or param_groups to optimize (like other Optimizer subclasses)
-       lr: The learning rate. We will typically use a learning rate schedule that starts
-           at 0.03 and decreases over time, i.e. much higher than other common
-           optimizers.
-       clipping_scale: (e.g. 2.0)
-           A scale for gradient-clipping: if specified, the normalized gradients
-           over the whole model will be clipped to have 2-norm equal to
-           `clipping_scale` times the median 2-norm over the most recent period
-           of `clipping_update_period` minibatches. By "normalized gradients",
-           we mean after multiplying by the rms parameter value for this tensor
-           [for non-scalars]; this is appropriate because our update is scaled
-           by this quantity.
-       betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
-           Must satisfy 0 < beta <= beta2 < 1.
-       scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
-           scale of each parameter tensor and scalar parameters of the mode..
-           If each parameter were decomposed
-           as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
-           would be a the scaling factor on the learning rate of p_scale.
-       eps: A general-purpose epsilon to prevent division by zero
-       param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
-           learning the scale on the parameters (we'll constrain the rms of each non-scalar
-           parameter tensor to be >= this value)
-       param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
-           learning the scale on the parameters (we'll constrain the rms of each non-scalar
-           parameter tensor to be <= this value)
-       scalar_max: Maximum absolute value for scalar parameters (applicable if your
-           model has any parameters with numel() == 1).
-       size_update_period: The periodicity, in steps, with which we update the size (scale)
-           of the parameter tensor. This is provided to save a little time
-           in the update.
-       clipping_update_period: if clipping_scale is specified, this is the period
-       percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited.
-       p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the
-           parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.
+    Args:
+      params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+      lr: The learning rate. We will typically use a learning rate schedule that starts
+          at 0.03 and decreases over time, i.e. much higher than other common
+          optimizers.
+      clipping_scale: (e.g. 2.0)
+          A scale for gradient-clipping: if specified, the normalized gradients
+          over the whole model will be clipped to have 2-norm equal to
+          `clipping_scale` times the median 2-norm over the most recent period
+          of `clipping_update_period` minibatches. By "normalized gradients",
+          we mean after multiplying by the rms parameter value for this tensor
+          [for non-scalars]; this is appropriate because our update is scaled
+          by this quantity.
+      betas: beta1, beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+          Must satisfy 0 < beta1 <= beta2 < 1.
+      scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+          scale of each parameter tensor and scalar parameters of the model.
+          If each parameter were decomposed
+          as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+          would be the scaling factor on the learning rate of p_scale.
+      eps: A general-purpose epsilon to prevent division by zero
+      param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+          learning the scale on the parameters (we'll constrain the rms of each non-scalar
+          parameter tensor to be >= this value)
+      param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+          learning the scale on the parameters (we'll constrain the rms of each non-scalar
+          parameter tensor to be <= this value)
+      scalar_max: Maximum absolute value for scalar parameters (applicable if your
+          model has any parameters with numel() == 1).
+      size_update_period: The periodicity, in steps, with which we update the size (scale)
+          of the parameter tensor. This is provided to save a little time
+          in the update.
+      clipping_update_period: if clipping_scale is specified, this is the period
+      percentile_limit: Parameter absolute values above the (1 - percentile_limit) percentile (e.g., 95%) will be limited.
+      p_limit_values: The probability (e.g., 0.1) of modifying the update sign, so as to limit
+          parameter absolute values that exceed the (1 - percentile_limit) percentile (e.g., 95%).
     """

     def __init__(
@@ -681,10 +681,14 @@ class ScaledAdam(BatchedOptimizer):
         p_exceed = p_abs > stored_percentiles  # same shape as p

         # The proportion that exceeds stored_percentiles
-        proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
+        proportion_exceed = (
+            p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
+        )
         # Update store_percentiles
         update_sign = (proportion_exceed - percentile_limit).sign()
-        stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)
+        stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(
+            min=1.0e-20
+        )

         # For these parameters that exceed stored_percentile,
         # flip the update sign if they would get larger absolute values
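
The second hunk only re-wraps the percentile-tracking update for the style checker; the logic is unchanged. For readers of the diff, here is a minimal standalone sketch of that update, assuming the batched layout used by BatchedOptimizer (dim 0 of `p` stacks individual parameter tensors, and `stored_percentiles` broadcasts against it). The function name `update_stored_percentiles` and the default values (0.05 read from the "95%" example, 0.1 from the docstring) are illustrative assumptions, not the actual icefall API.

import torch


def update_stored_percentiles(
    p: torch.Tensor,
    stored_percentiles: torch.Tensor,
    lr: float,
    percentile_limit: float = 0.05,
    p_limit_values: float = 0.1,
) -> torch.Tensor:
    # p: batched parameter of shape (batch_size, ...); stored_percentiles is
    # broadcastable to p, e.g. shape (batch_size, 1, ..., 1).
    numel = p[0].numel()  # elements per individual parameter tensor
    p_abs = p.abs()
    p_exceed = p_abs > stored_percentiles  # same shape as p

    # The proportion of entries whose magnitude exceeds the stored threshold.
    proportion_exceed = (
        p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
    )

    # Raise the threshold when too many entries exceed it, lower it otherwise.
    update_sign = (proportion_exceed - percentile_limit).sign()
    stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)
    return stored_percentiles

The multiplicative update keeps the threshold strictly positive, and the clamp at 1.0e-20 stops it from collapsing to zero when no entries exceed it.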