From cbfd459df7803e3fb2e6ae1fb1c2d1ea78b9ab35 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 18 Mar 2023 21:47:08 +0800 Subject: [PATCH] add _limit_values_sign in ScaledAdam --- .../ASR/pruned_transducer_stateless7/optim.py | 64 ++++++++++++++++++- .../ASR/pruned_transducer_stateless7/train.py | 10 +++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 374b78cb3..209780bd7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -16,6 +16,7 @@ import contextlib import logging +import math import random from collections import defaultdict from typing import List, Optional, Tuple, Union @@ -163,6 +164,11 @@ class ScaledAdam(BatchedOptimizer): of the parameter tensor. This is provided to save a little time in the update. clipping_update_period: if clipping_scale is specified, this is the period + p_limit_values: The probability (e.g., 0.1) to modify the update sign so as to prevent + absolute-values of any weight tensor from being over a certain percentile of + the distribution of that parameter tensor's absolute values. + percentile_limit: The percentile (e.g., 0.9) over which the parameter absolute values would be + limited. """ def __init__( @@ -180,6 +186,8 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period=100, parameters_names=None, show_dominant_parameters=True, + p_limit_values=0.0, + percentile_limit=0.9, ): assert parameters_names is not None, ( @@ -198,6 +206,8 @@ class ScaledAdam(BatchedOptimizer): scalar_max=scalar_max, size_update_period=size_update_period, clipping_update_period=clipping_update_period, + p_limit_values=p_limit_values, + percentile_limit=percentile_limit, ) super(ScaledAdam, self).__init__(params, defaults) @@ -296,6 +306,9 @@ class ScaledAdam(BatchedOptimizer): size_update_period, *param_rms.shape, **kwargs ) + if group["p_limit_values"] > 0: + state["stored_percentiles"] = torch.ones_like(param_rms) + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) @@ -603,7 +616,12 @@ class ScaledAdam(BatchedOptimizer): alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] - delta.add_(grad * alpha) + + if random.random() >= group["p_limit_values"]: + delta.add_(grad * alpha) + else: + delta.add_((grad * alpha) * self._limit_values_sign(group, p, grad, state)) + p.add_(delta) def _step_scalar(self, group: dict, p: Tensor, state: dict): @@ -630,6 +648,50 @@ class ScaledAdam(BatchedOptimizer): p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + def _limit_values_sign(self, group: dict, p: Tensor, grad: Tensor, state: dict): + """Decide whether to modify the sign of the update. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + Returns: A tensor with same shape as p, filled with 1 or -1. + """ + lr = group["lr"] + p_limit_values = group["p_limit_values"] # e.g., 0.1 + percentile_limit = group["percentile_limit"] # e.g., 0.9 + # it has a shape like (batch_size, 1, 1, 1, 1) + stored_percentiles = state["stored_percentiles"] + + p_abs = p.abs() + dtype = p.dtype + batch_size = p.shape[0] + + numel = p.numel() / batch_size + k = math.ceil(numel * (1 - percentile_limit)) + percentiles = p_abs.view(batch_size, -1).topk(k=k, dim=-1)[0][:, -1] # (batch,) + + # If True, stored_percentiles should be increased + percentiles_exceed = percentiles.view(stored_percentiles.shape) > stored_percentiles + + # Update store_percentiles + update_sign = (percentiles_exceed.to(dtype) - 0.5).sign() + stored_percentiles.mul_(1 + update_sign * lr / p_limit_values).clamp_(min=1.0e-20) + + p_exceed = p_abs > stored_percentiles + # if random.random() < 0.1: + # # print(stored_percentiles) + # # print(percentiles_exceed) + # print(p_exceed.sum(dim=list(range(1, p.ndim))) / numel) + + # Decide whether to change grad sign + limit_sign = (~percentiles_exceed * p_exceed) * ((p.sign() * grad.sign()) < 0) + limit_sign = (limit_sign.to(dtype) - 0.5).sign() + + return -1 * limit_sign + class LRScheduler(object): """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 792a243e5..2c4d009ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -374,6 +374,15 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--p-limit-values", + type=float, + default=0.0, + help="""The probability (e.g., 0.1) to modify the update sign so as to prevent + absolute-values of any weight tensor from being over a certain percentile of + the distribution of that parameter tensor's absolute values""", + ) + add_model_arguments(parser) return parser @@ -1016,6 +1025,7 @@ def run(rank, world_size, args): lr=params.base_lr, clipping_scale=2.0, parameters_names=parameters_names, + p_limit_values=params.p_limit_values, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)