From 8c91b9d0cd2ff84d265b04da5d77f4d40c58624f Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Wed, 22 Mar 2023 19:31:10 +0800
Subject: [PATCH] correct errors in _limit_values_sign

---
 .../ASR/pruned_transducer_stateless7/optim.py | 68 +++++++++++--------
 .../ASR/pruned_transducer_stateless7/train.py | 14 +++-
 2 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 209780bd7..4c9010124 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -164,11 +164,9 @@ class ScaledAdam(BatchedOptimizer):
          of the parameter tensor. This is provided to save a little time
          in the update.
     clipping_update_period: if clipping_scale is specified, this is the period
-    p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent
-      absolute-values of any weight tensor from being over a certain percentile of
-      the distribution of that parameter tensor's absolute values.
-    percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be
-      limited.
+    percentile_limit: Parameter absolute values above the (1 - percentile_limit) percentile (e.g., the 95th for 0.05) will be limited.
+    p_limit_values: The probability (e.g., 0.1) of modifying the update sign on a given step, so as to limit
+      parameter absolute values that exceed the (1 - percentile_limit) percentile.
     """

     def __init__(
@@ -186,8 +184,8 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period=100,
         parameters_names=None,
         show_dominant_parameters=True,
+        percentile_limit=0.05,
         p_limit_values=0.0,
-        percentile_limit=0.9,
     ):

         assert parameters_names is not None, (
@@ -206,8 +204,8 @@ class ScaledAdam(BatchedOptimizer):
             scalar_max=scalar_max,
             size_update_period=size_update_period,
             clipping_update_period=clipping_update_period,
-            p_limit_values=p_limit_values,
             percentile_limit=percentile_limit,
+            p_limit_values=p_limit_values,
         )

         super(ScaledAdam, self).__init__(params, defaults)
@@ -649,7 +647,16 @@ class ScaledAdam(BatchedOptimizer):
             p.add_(delta)

     def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict):
-        """Decide whether to modify the sign of the update.
+        """Decide whether to modify the sign of the current update.
+
+        For each parameter tensor in p, we store a certain percentile (e.g., the 95th)
+        of its absolute values, called stored_percentiles.
+        If more than percentile_limit (e.g., 5%) of the absolute values exceed stored_percentiles:
+            scale stored_percentiles up by a factor of (1.0 + lr / p_limit_values).
+        Else:
+            scale stored_percentiles down by a factor of (1.0 - lr / p_limit_values).
+        For all parameters whose absolute values exceed stored_percentiles,
+        modify the sign of the update so that it points 'inward'.

         Args:
           group: A dict which will be used to look up configuration values
@@ -660,37 +667,40 @@ class ScaledAdam(BatchedOptimizer):
         Returns:
             A tensor with same shape as p, filled with 1 or -1.
         """
         lr = group["lr"]
+        # The probability of limiting the absolute values
         p_limit_values = group["p_limit_values"]  # e.g., 0.1
-        percentile_limit = group["percentile_limit"]  # e.g., 0.9
-        # it has a shape like (batch_size, 1, 1, 1, 1)
+        # Parameter absolute values above the (1 - percentile_limit) percentile will be limited.
+        percentile_limit = group["percentile_limit"]  # e.g., 0.05
+        # It stores the (1 - percentile_limit) percentiles of the parameter absolute values,
+        # with a shape like (batch_size, 1, 1, 1, 1)
         stored_percentiles = state["stored_percentiles"]
-        p_abs = p.abs()
-        dtype = p.dtype
         batch_size = p.shape[0]
+        numel = p.numel() // batch_size
+        p_abs = p.abs()

-        numel = p.numel() / batch_size
-        k = math.ceil(numel * (1 - percentile_limit))
-        percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1]  # (batch,)
-
-        # If True, stored_percentiles should be increased
-        percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles
-
+        p_exceed = p_abs > stored_percentiles  # same shape as p
+        # The proportion that exceeds stored_percentiles
+        proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel
         # Update store_percentiles
-        update_sign = (percentiles_exceed.to(dtype) - 0.5).sign()
+        update_sign = (proportion_exceed - percentile_limit).sign()
         stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)

-        p_exceed = p_abs > stored_percentiles
-        # if random.random() < 0.1:
-        #     # print(stored_percentiles)
-        #     # print(percentiles_exceed)
-        print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel)
+        # For those parameters that exceed stored_percentiles,
+        # flip the update sign if they would get larger absolute values
+        limit_sign = 1 - 2 * p_exceed * ((p.sign() * grad.sign()) < 0)

-        # Decide whether to change grad sign
-        limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0)
-        limit_sign = (limit_sign.to(dtype) - 0.5).sign()
+        # TODO: will remove the log below
+        if random.random() < 0.1:
+            logging.info(f"p.shape: {p.shape}")
+            logging.info(f"stored_percentiles: {stored_percentiles.view(-1)}")
+            logging.info(f"proportion_exceed: {proportion_exceed.view(-1)}")
+            logging.info(f"p_abs_max: {p_abs.view(batch_size, -1).max(dim=-1)[0]}")
+            mask = limit_sign == -1
+            proportion_sign_change = mask.sum(dim=list(range(1, p.ndim))) / numel
+            logging.info(f"proportion_sign_change: {proportion_sign_change}")

-        return -1 * limit_sign
+        return limit_sign


 class LRScheduler(object):

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 2c4d009ae..9f3b010a2 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -378,9 +378,16 @@ def get_parser():
         "--p-limit-values",
         type=float,
         default=0.0,
-        help="""The probability (e.g., 0.1) to modify the update sign so as to prevent
-        absolute-values of any weight tensor from being over a certain percentile of
-        the distribution of that parameter tensor's absolute values""",
+        help="""The probability (e.g., 0.1) of modifying the update sign on a given step,
+        so as to limit parameter absolute values that exceed the (1 - percentile_limit) percentile.""",
     )
+
+    parser.add_argument(
+        "--percentile-limit",
+        type=float,
+        default=0.05,
+        help="""Parameter absolute values above the (1 - percentile_limit) percentile
+        (e.g., the 95th percentile for 0.05) will be limited.""",
+    )

     add_model_arguments(parser)
@@ -1026,6 +1033,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
         parameters_names=parameters_names,
         p_limit_values=params.p_limit_values,
+        percentile_limit=params.percentile_limit,
     )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
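
For reference, below is a minimal standalone sketch of the percentile-limiting idea that _limit_values_sign implements, written against plain PyTorch tensors rather than the ScaledAdam/BatchedOptimizer internals. The helper name limit_values_sign_sketch, the default hyper-parameter values, and the driver code at the bottom are illustrative assumptions, not part of the patch; only the tensor logic mirrors the patched function.

# Standalone sketch of the percentile-based update-sign limiting used above.
# Assumes a batched parameter tensor of shape (batch_size, ...), as in BatchedOptimizer;
# `stored_percentiles` would normally live in the optimizer state and is updated in place.
import torch


def limit_values_sign_sketch(
    p: torch.Tensor,                   # batched parameters, shape (batch_size, ...)
    grad: torch.Tensor,                # gradients, same shape as p
    stored_percentiles: torch.Tensor,  # shape (batch_size, 1, ..., 1)
    lr: float = 0.03,
    percentile_limit: float = 0.05,    # allow ~5% of entries above the stored percentile
    p_limit_values: float = 0.1,
) -> torch.Tensor:
    batch_size = p.shape[0]
    numel = p.numel() // batch_size
    p_abs = p.abs()

    # Entries whose absolute value already exceeds the stored percentile,
    # and the fraction of each tensor that they represent.
    p_exceed = p_abs > stored_percentiles
    proportion_exceed = p_exceed.sum(dim=list(range(1, p.ndim)), keepdim=True) / numel

    # Grow the stored percentile when too many entries exceed it, shrink it otherwise.
    update_sign = (proportion_exceed - percentile_limit).sign()
    stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20)

    # An Adam-like step moves roughly along -grad, so sign(p) != sign(grad) means the
    # step would push |p| further from zero; return -1 there (flip the update), +1 elsewhere.
    flip = p_exceed & ((p.sign() * grad.sign()) < 0)
    return 1 - 2 * flip.to(p.dtype)


# Example: flip the per-entry update direction for out-of-range entries.
p = torch.randn(4, 8, 8)
grad = torch.randn_like(p)
stored_percentiles = torch.full((4, 1, 1), 0.5)
delta = -0.01 * grad * limit_values_sign_sketch(p, grad, stored_percentiles)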