From cbfd459df7803e3fb2e6ae1fb1c2d1ea78b9ab35 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 18 Mar 2023 21:47:08 +0800 Subject: [PATCH 1/4] add _limit_values_sign in ScaledAdam --- .../ASR/pruned_transducer_stateless7/optim.py | 64 ++++++++++++++++++- .../ASR/pruned_transducer_stateless7/train.py | 10 +++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 374b78cb3..209780bd7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -16,6 +16,7 @@ import contextlib import logging +import math import random from collections import defaultdict from typing import List, Optional, Tuple, Union @@ -163,6 +164,11 @@ class ScaledAdam(BatchedOptimizer): of the parameter tensor. This is provided to save a little time in the update. clipping_update_period: if clipping_scale is specified, this is the period + p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent + absolute-values of any weight tensor from being over a certain percentile of + the distribution of that parameter tensor's absolute values. + percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be + limited. """ def __init__( @@ -180,6 +186,8 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period=100, parameters_names=None, show_dominant_parameters=True, + p_limit_values=0.0, + percentile_limit=0.9, ): assert parameters_names is not None, ( @@ -198,6 +206,8 @@ class ScaledAdam(BatchedOptimizer): scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, + p_limit_values=p_limit_values, + percentile_limit=percentile_limit, ) super(ScaledAdam, self).__init__(params, defaults) @@ -296,6 +306,9 @@ class ScaledAdam(BatchedOptimizer): size_update_period, *param_rms.shape, **kwargs ) + if group["p_limit_values"] > 0: + state["stored_percentiles"] = torch.ones_like(param_rms) + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) @@ -603,7 +616,12 @@ class ScaledAdam(BatchedOptimizer): alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] - delta.add_(grad * alpha) + + if random.random() >= group["p_limit_values"]: + delta.add_(grad * alpha) + else: + delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state)) + p.add_(delta) def _step_scalar(self, group: dict, p: Tensor, state: dict): @@ -630,6 +648,50 @@ class ScaledAdam(BatchedOptimizer): p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict): + """Decide whether to modify the sign of the update. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + Returns: A tensor with same shape as p, filled with 1 or -1. 
+ """ + lr = group["lr"] + p_limit_values = group["p_limit_values"] # e.g., 0.1 + percentile_limit = group["percentile_limit"] # e.g., 0.9 + # it has a shape like (batch_size, 1, 1, 1, 1) + stored_percentiles = state["stored_percentiles"] + + p_abs = p.abs() + dtype = p.dtype + batch_size = p.shape[0] + + numel = p.numel() / batch_size + k = math.ceil(numel * (1 - percentile_limit)) + percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1] # (batch,) + + # If True, stored_percentiles should be increased + percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles + + # Update store_percentiles + update_sign = (percentiles_exceed.to(dtype) - 0.5).sign() + stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20) + + p_exceed = p_abs > stored_percentiles + # if random.random() < 0.1: + # # print(stored_percentiles) + # # print(percentiles_exceed) + # print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel) + + # Decide whether to change grad sign + limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0) + limit_sign = (limit_sign.to(dtype) - 0.5).sign() + + return -1 * limit_sign + class LRScheduler(object): """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 792a243e5..2c4d009ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -374,6 +374,15 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--p-limit-values", + type=float, + default=0.0, + help="""The probability (e.g., 0.1) to modify the update sign so as to prevent + absolute-values of any weight tensor from being over a certain percentile of + the distribution of that parameter tensor's absolute values""", + ) + add_model_arguments(parser) return parser @@ -1016,6 +1025,7 @@ def run(rank, world_size, args): lr=params.base_lr, clipping_scale=2.0, parameters_names=parameters_names, + p_limit_values=params.p_limit_values, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From 8c91b9d0cd2ff84d265b04da5d77f4d40c58624f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 22 Mar 2023 19:31:10 +0800 Subject: [PATCH 2/4] correct errors in _limit_values_sign --- .../ASR/pruned_transducer_stateless7/optim.py | 68 +++++++++++-------- .../ASR/pruned_transducer_stateless7/train.py | 14 +++- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 209780bd7..4c9010124 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -164,11 +164,9 @@ class ScaledAdam(BatchedOptimizer): of the parameter tensor. This is provided to save a little time in the update. clipping_update_period: if clipping_scale is specified, this is the period - p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent - absolute-values of any weight tensor from being over a certain percentile of - the distribution of that parameter tensor's absolute values. - percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be - limited. + percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited. 
+ p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the + parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile. """ def __init__( @@ -186,8 +184,8 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period=100, parameters_names=None, show_dominant_parameters=True, + percentile_limit=0.05, p_limit_values=0.0, - percentile_limit=0.9, ): assert parameters_names is not None, ( @@ -206,8 +204,8 @@ class ScaledAdam(BatchedOptimizer): scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, - p_limit_values=p_limit_values, percentile_limit=percentile_limit, + p_limit_values=p_limit_values, ) super(ScaledAdam, self).__init__(params, defaults) @@ -649,7 +647,16 @@ class ScaledAdam(BatchedOptimizer): p.add_(delta) def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict): - """Decide whether to modify the sign of the update. + """Decide whether to modify the sign of the cerrent update. + + For each parameter tensor in p, we store the centain percentile (e.g., 95%), + say stored_percentiles. + If more than 5% parameter absolute values are larger than stored_percentiles: + Increase stored_percentiles by (1.0 + lr / p_limit_values). + Else: + Decrease stored_percentiles by (1.0 - lr / p_limit_values). + For all parameters whose absolute values are larger than stored_percentiles, + modify the sign of the update so that it points 'inward'. Args: group: A dict which will be used to look up configuration values @@ -660,37 +667,40 @@ class ScaledAdam(BatchedOptimizer): Returns: A tensor with same shape as p, filled with 1 or -1. """ lr = group["lr"] + # The probability to limit the absolute values p_limit_values = group["p_limit_values"] # e.g., 0.1 - percentile_limit = group["percentile_limit"] # e.g., 0.9 - # it has a shape like (batch_size, 1, 1, 1, 1) + # The parameter absolute values over 1-percentile_limit percentile will be limited. 
+ percentile_limit = group["percentile_limit"] # e.g., 0.05 + # It stores the percentiles (i.e, 1-percentile_limit) of the parameter absolute values, + # with a shape like (batch_size, 1, 1, 1, 1) stored_percentiles = state["stored_percentiles"] - p_abs = p.abs() - dtype = p.dtype batch_size = p.shape[0] + numel = p.numel() // batch_size + p_abs = p.abs() - numel = p.numel() / batch_size - k = math.ceil(numel * (1 - percentile_limit)) - percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1] # (batch,) - - # If True, stored_percentiles should be increased - percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles - + p_exceed = p_abs > stored_percentiles # same shape as p + # The proportion that exceeds stored_percentiles + proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel # Update store_percentiles - update_sign = (percentiles_exceed.to(dtype) - 0.5).sign() + update_sign = (proportion_exceed - percentile_limit).sign() stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20) - p_exceed = p_abs > stored_percentiles - # if random.random() < 0.1: - # # print(stored_percentiles) - # # print(percentiles_exceed) - # print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel) + # For these parameters that exceed stored_percentile, + # flip the update sign if they would get larger absolute values + limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) - # Decide whether to change grad sign - limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0) - limit_sign = (limit_sign.to(dtype) - 0.5).sign() + # TODO: will remove the log bellow + if random.random() < 0.1: + logging.info(f"p.shape: {p.shape}") + logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") + logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") + logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") + mask = limit_sign == -1 + proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel + logging.info(f"proportion_sign_change: {proportion_sign_change}") - return -1 * limit_sign + return limit_sign class LRScheduler(object): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 2c4d009ae..9f3b010a2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -378,9 +378,16 @@ def get_parser(): "--p-limit-values", type=float, default=0.0, - help="""The probability (e.g., 0.1) to modify the update sign so as to prevent - absolute-values of any weight tensor from being over a certain percentile of - the distribution of that parameter tensor's absolute values""", + help="""The probability (e.g., 0.1) to modify the update sign, so as to limit the + parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile.""", + ) + + parser.add_argument( + "--percentile-limit", + type=float, + default=0.05, + help="""The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile + will be limited.""", ) add_model_arguments(parser) @@ -1026,6 +1033,7 @@ def run(rank, world_size, args): clipping_scale=2.0, parameters_names=parameters_names, p_limit_values=params.p_limit_values, + percentile_limit=params.percentile_limit, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From df4b615993bea4bedca377765f58c14560c161e3 Mon Sep 17 00:00:00 2001 From: Yifan Yang 
<64255737+yfyeung@users.noreply.github.com> Date: Thu, 23 Mar 2023 17:27:57 +0800 Subject: [PATCH 3/4] Fix for style_check --- .../ASR/pruned_transducer_stateless7/optim.py | 88 ++++++++++--------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 4c9010124..330e23abf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -125,48 +125,48 @@ class BatchedOptimizer(Optimizer): class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period - percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited. 
- p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the - parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile. + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + percentile_limit: The parameter absolute values over 1-percentile_limit (e.g., 95%) percentile will be limited. + p_limit_values: The probability (e.g., 0.1) to modify the update sign, so as to limit the + parameter absolute values that are larger than 1-percentile_limit (e.g., 95%) percentile. 
""" def __init__( @@ -681,10 +681,14 @@ class ScaledAdam(BatchedOptimizer): p_exceed = p_abs > stored_percentiles # same shape as p # The proportion that exceeds stored_percentiles - proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel + proportion_exceed = ( + p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel + ) # Update store_percentiles update_sign = (proportion_exceed - percentile_limit).sign() - stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20) + stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_( + min=1.0e-20 + ) # For these parameters that exceed stored_percentile, # flip the update sign if they would get larger absolute values From 5e58d2de75cd619cf2b4eb7e1c1e3a51f49d377f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 23 Mar 2023 20:06:03 +0800 Subject: [PATCH 4/4] add condition of p_abs > 10 --- .../ASR/pruned_transducer_stateless7/optim.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 4c9010124..f2fa6e85a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -688,17 +688,18 @@ class ScaledAdam(BatchedOptimizer): # For these parameters that exceed stored_percentile, # flip the update sign if they would get larger absolute values - limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) + limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0) * (p_abs > 10) # TODO: will remove the log bellow if random.random() < 0.1: - logging.info(f"p.shape: {p.shape}") - logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") - logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") - logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") mask = limit_sign == -1 - proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel - logging.info(f"proportion_sign_change: {proportion_sign_change}") + if mask.any(): + logging.info(f"p.shape: {p.shape}") + logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}") + logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}") + logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}") + proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel + logging.info(f"proportion_sign_change: {proportion_sign_change}") return limit_sign