diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8ab3589da..ee7c2d44e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -16,6 +16,7 @@
 import contextlib
 import logging
+import math
 import random
 from collections import defaultdict
 from typing import List, Optional, Tuple, Union
@@ -124,45 +125,48 @@ class BatchedOptimizer(Optimizer):

 class ScaledAdam(BatchedOptimizer):
     """
-     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
-     proportional to the norm of that parameter; and also learn the scale of the parameter,
-     in log space, subject to upper and lower limits (as if we had factored each parameter as
-     param = underlying_param * log_scale.exp())
-
-     Args:
-          params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
-              lr:  The learning rate.  We will typically use a learning rate schedule that starts
-                   at 0.03 and decreases over time, i.e. much higher than other common
-                   optimizers.
-   clipping_scale: (e.g. 2.0)
-                   A scale for gradient-clipping: if specified, the normalized gradients
-                   over the whole model will be clipped to have 2-norm equal to
-                   `clipping_scale` times the median 2-norm over the most recent period
-                   of `clipping_update_period` minibatches.  By "normalized gradients",
-                   we mean after multiplying by the rms parameter value for this tensor
-                   [for non-scalars]; this is appropriate because our update is scaled
-                   by this quantity.
-            betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
-                   Must satisfy 0 < beta <= beta2 < 1.
-  scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
-                   scale of each parameter tensor and scalar parameters of the mode..
-                   If each parameter were decomposed
-                   as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
-                   would be a the scaling factor on the learning rate of p_scale.
-              eps: A general-purpose epsilon to prevent division by zero
-    param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
-                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
-                   parameter tensor to be >= this value)
-    param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
-                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
-                   parameter tensor to be <= this value)
-       scalar_max: Maximum absolute value for scalar parameters (applicable if your
-                   model has any parameters with numel() == 1).
- size_update_period: The periodicity, in steps, with which we update the size (scale)
-                   of the parameter tensor.  This is provided to save a little time
-                   in the update.
- clipping_update_period: if clipping_scale is specified, this is the period
+    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+    proportional to the norm of that parameter; and also learn the scale of the parameter,
+    in log space, subject to upper and lower limits (as if we had factored each parameter as
+    param = underlying_param * log_scale.exp())
+
+    Args:
+        params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+        lr: The learning rate.  We will typically use a learning rate schedule that starts
+            at 0.03 and decreases over time, i.e. much higher than other common
+            optimizers.
+        clipping_scale: (e.g. 2.0)
+            A scale for gradient-clipping: if specified, the normalized gradients
+            over the whole model will be clipped to have 2-norm equal to
+            `clipping_scale` times the median 2-norm over the most recent period
+            of `clipping_update_period` minibatches.  By "normalized gradients",
+            we mean after multiplying by the rms parameter value for this tensor
+            [for non-scalars]; this is appropriate because our update is scaled
+            by this quantity.
+        betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+            Must satisfy 0 < beta1 <= beta2 < 1.
+        scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+            scale of each parameter tensor and scalar parameters of the model.
+            If each parameter were decomposed
+            as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+            would be the scaling factor on the learning rate of p_scale.
+        eps: A general-purpose epsilon to prevent division by zero
+        param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+            learning the scale on the parameters (we'll constrain the rms of each non-scalar
+            parameter tensor to be >= this value)
+        param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+            learning the scale on the parameters (we'll constrain the rms of each non-scalar
+            parameter tensor to be <= this value)
+        scalar_max: Maximum absolute value for scalar parameters (applicable if your
+            model has any parameters with numel() == 1).
+        size_update_period: The periodicity, in steps, with which we update the size (scale)
+            of the parameter tensor.  This is provided to save a little time
+            in the update.
+        clipping_update_period: if clipping_scale is specified, this is the period
+        percentile_limit: Parameter absolute values above the (1 - percentile_limit) percentile (e.g., the 95th for 0.05) will be limited.
+        p_limit_values: The probability (e.g., 0.1) of modifying the update sign, so as to
+            limit parameter absolute values that exceed the (1 - percentile_limit) percentile.
     """

     def __init__(
@@ -180,6 +184,8 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period=100,
         parameters_names=None,
         show_dominant_parameters=True,
+        percentile_limit=0.05,
+        p_limit_values=0.0,
     ):
         assert parameters_names is not None, (
             "Please prepare parameters_names,"
@@ -197,6 +203,8 @@ class ScaledAdam(BatchedOptimizer):
             scalar_max=scalar_max,
             size_update_period=size_update_period,
             clipping_update_period=clipping_update_period,
+            percentile_limit=percentile_limit,
+            p_limit_values=p_limit_values,
         )

         super(ScaledAdam, self).__init__(params, defaults)
@@ -292,6 +300,9 @@ class ScaledAdam(BatchedOptimizer):
                 size_update_period, *param_rms.shape, **kwargs
             )

+            if group["p_limit_values"] > 0:
+                state["stored_percentiles"] = torch.ones_like(param_rms)
+
             # exp_avg_sq is the weighted sum of scaled gradients.  as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) @@ -598,7 +609,12 @@ class ScaledAdam(BatchedOptimizer): alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] - delta.add_(grad * alpha) + + if random.random() >= group["p_limit_values"]: + delta.add_(grad * alpha) + else: + delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state)) + p.add_(delta) def _step_scalar(self, group: dict, p: Tensor, state: dict): @@ -625,6 +641,67 @@ class ScaledAdam(BatchedOptimizer): p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict): + """Decide whether to modify the sign of the cerrent update. + + For each parameter tensor in p, we store the centain percentile (e.g., 95%), + say stored_percentiles. + If more than 5% parameter absolute values are larger than stored_percentiles: + Increase stored_percentiles by (1.0 + lr / p_limit_values). + Else: + Decrease stored_percentiles by (1.0 - lr / p_limit_values). + For all parameters whose absolute values are larger than stored_percentiles, + modify the sign of the update so that it points 'inward'. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + Returns: A tensor with same shape as p, filled with 1 or -1. + """ + lr = group["lr"] + # The probability to limit the absolute values + p_limit_values = group["p_limit_values"] # e.g., 0.1 + # The parameter absolute values over 1-percentile_limit percentile will be limited. + percentile_limit = group["percentile_limit"] # e.g., 0.05 + # It stores the percentiles (i.e, 1-percentile_limit) of the parameter absolute values, + # with a shape like (batch_size, 1, 1, 1, 1) + stored_percentiles = state["stored_percentiles"] + + batch_size = p.shape[0] + numel = p.numel() // batch_size + p_abs = p.abs() + + p_exceed = p_abs > stored_percentiles # same shape as p + # The proportion that exceeds stored_percentiles + proportion_exceed = ( + p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel + ) + # Update store_percentiles + update_sign = (proportion_exceed - percentile_limit).sign() + stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_( + min=1.0e-20 + ) + + # For these parameters that exceed stored_percentile, + # flip the update sign if they would get larger absolute values + limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10) + + # TODO: will remove the log bellow + if random.random() < 0.1: + mask = limit_sign == -1 + if mask.any(): + logging.info(f"p.shape: {p.shape}") + logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") + logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") + logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") + proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel + logging.info(f"proportion_sign_change: {proportion_sign_change}") + + return limit_sign + class LRScheduler(object): """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 436ec53b4..05c2bcae1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -376,6 +376,22 @@ def get_parser(): help="Whether to use half precision training.", ) + 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 436ec53b4..05c2bcae1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -376,6 +376,22 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--p-limit-values",
+        type=float,
+        default=0.0,
+        help="""The probability (e.g., 0.1) of modifying the update sign, so as to limit
+        parameter absolute values that exceed the (1 - percentile_limit) percentile.""",
+    )
+
+    parser.add_argument(
+        "--percentile-limit",
+        type=float,
+        default=0.05,
+        help="""Parameter absolute values above the (1 - percentile_limit) percentile
+        (e.g., the 95th percentile for 0.05) will be limited.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -1009,6 +1025,8 @@ def run(rank, world_size, args):
         lr=params.base_lr,
         clipping_scale=2.0,
         parameters_names=parameters_names,
+        p_limit_values=params.p_limit_values,
+        percentile_limit=params.percentile_limit,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
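For completeness, a hedged sketch of constructing ScaledAdam with the new options outside of train.py, mirroring the run() change above. The toy nn.Linear model, the lr value, and the way parameters_names is assembled here are placeholders for illustration; the training script builds them from the full transducer model.

import torch

from optim import ScaledAdam  # egs/librispeech/ASR/pruned_transducer_stateless7/optim.py

# Placeholder model; the recipe uses the full transducer model instead.
model = torch.nn.Linear(80, 512)
parameters_names = [[name for name, _ in model.named_parameters()]]

optimizer = ScaledAdam(
    model.parameters(),
    lr=0.03,
    clipping_scale=2.0,
    parameters_names=parameters_names,
    # New in this patch: with probability p_limit_values, multiply the update by
    # _limit_values_sign(), which flips the sign for entries above the tracked
    # (1 - percentile_limit) percentile that would otherwise keep growing.
    p_limit_values=0.1,
    percentile_limit=0.05,
)

# One illustrative step.
loss = model(torch.randn(2, 80)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

From the command line, the same options are exposed through the new --p-limit-values and --percentile-limit flags added to get_parser() and forwarded to ScaledAdam in run().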