diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 79acd7ca3..d50198388 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -40,6 +40,14 @@ class ScaledAdam(Optimizer):
      lr:  The learning rate.  We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
+    clipping_scale: (e.g. 2.0)
+          A scale for gradient-clipping: if specified, the normalized gradients
+          over the whole model will be clipped to have 2-norm equal to
+          `clipping_scale` times the median 2-norm over the most recent period
+          of `clipping_update_period` minibatches.  By "normalized gradients",
+          we mean after multiplying by the rms parameter value for this tensor
+          [for non-scalars]; this is appropriate because our update is scaled
+          by this quantity.
      betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
             Must satisfy 0 < beta <= beta2 < 1.
      size_lr_scale: A scaling factor on the learning rate, that we use to update the
@@ -57,12 +65,14 @@ class ScaledAdam(Optimizer):
                     model has any parameters with numel() == 1).
     size_update_period: The periodicity, in steps, with which we update the size (scale)
                     of the parameter tensor.  This is provided to save a little time
-                    in the update.
+                    in the update.
+    clipping_update_period: if clipping_scale is specified, this is the period (in minibatches) with which we update the gradient-norm statistics and clipping threshold.
     """

     def __init__(
         self,
         params,
         lr=3e-02,
+        clipping_scale=None,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
         eps=1.0e-08,
@@ -70,12 +80,14 @@
         param_max_rms=2.0,
         scalar_max=2.0,
         size_update_period=4,
+        clipping_update_period=100,
     ):

         defaults = dict(
             lr=lr,
+            clipping_scale=clipping_scale,
             betas=betas,
             size_lr_scale=size_lr_scale,
             eps=eps,
@@ -83,6 +95,7 @@
             param_max_rms=param_max_rms,
             scalar_max=scalar_max,
             size_update_period=size_update_period,
+            clipping_update_period=clipping_update_period,
         )

         super(ScaledAdam, self).__init__(params, defaults)
@@ -106,7 +119,9 @@
         batch = True

         for group in self.param_groups:
-            for p in group["params"]:
+
+
+            for i, p in enumerate(group["params"]):
                 state = self.state[p]

                 # Perform optimization step
@@ -118,7 +133,11 @@
                 # State initialization
                 if len(state) == 0:
                     self._init_state(group, p, state)
-                self._step_one_batch(group, p, state)
+
+                if i == 0:
+                    clipping_scale = self._get_clipping_scale(group, p, state)
+
+                self._step_one_batch(group, p, state, clipping_scale)

         return loss
@@ -170,11 +189,79 @@
                     p, memory_format=torch.preserve_format
                 )

+    def _get_clipping_scale(self,
+                            group: dict,
+                            p: Tensor,
+                            state: dict) -> float:
+        """
+        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will
+        scale the gradients by this amount before applying the rest of the update.
+        This function is only to be called for the first parameter in the group.
+        """
+        clipping_scale = group["clipping_scale"]
+        step = state["step"]
+        if clipping_scale is None or step == 0:
+            # no clipping.  return early on step == 0 because the other
+            # parameters' state won't have been initialized yet.
+            return 1.0
+        clipping_update_period = group["clipping_update_period"]
+
+        tot_sumsq = torch.tensor(0.0, device=p.device)
+        for p in group["params"]:
+            state = self.state[p]
+            grad = p.grad
+            if grad.is_sparse:
+                raise RuntimeError(
+                    "ScaledAdam optimizer does not support sparse gradients"
+                )
+            if p.numel() == 1:
+                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
+            else:
+                tot_sumsq += (grad**2).sum() * (state["param_rms"] ** 2)
+        tot_norm = tot_sumsq.sqrt()
+        if "model_norms" not in state:
+            state["model_norms"] = torch.zeros(clipping_update_period,
+                                               device=p.device)
+        state["model_norms"][step % clipping_update_period] = tot_norm
+
+        if step % clipping_update_period == 0:
+            # We don't reach here if step == 0 because we
+            # would have returned above.
+            sorted_norms = state["model_norms"].sort()[0].to('cpu')
+            quartiles = []
+            for n in range(0, 5):
+                index = min(clipping_update_period - 1,
+                            (clipping_update_period // 4) * n)
+                quartiles.append(sorted_norms[index].item())
+
+            median = quartiles[2]
+            threshold = clipping_scale * median
+            state["model_norm_threshold"] = threshold
+            percent_clipped = (state["num_clipped"] * 100.0 / clipping_update_period
+                               if "num_clipped" in state else 0.0)
+            state["num_clipped"] = 0
+            quartiles = ['%.3e' % x for x in quartiles]
+            logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+                         f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}")
+
+        if step < clipping_update_period:
+            return 1.0  # We have not yet estimated a norm to clip to.
+        else:
+            model_norm_threshold = state["model_norm_threshold"]
+            ans = min(1.0,
+                      (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+            if ans < 1.0:
+                state["num_clipped"] += 1
+            if ans < 0.1:
+                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
+            return ans
+
     def _step_one_batch(self,
                         group: dict,
                         p: Tensor,
-                        state: dict):
+                        state: dict,
+                        clipping_scale: float):
         """
         Do the step for parameter p.

         Args:
@@ -188,6 +275,8 @@
         beta1 = group["betas"][0]
         grad = p.grad
+        if clipping_scale != 1.0:
+            grad = grad * clipping_scale
         step = state["step"]
         delta = state["delta"]
@@ -692,7 +781,7 @@ def _test_scaled_adam(hidden_dim: int):
               torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes)
              for _ in range(20) ]

     if iter == 0: optim = Eve(m.parameters(), lr=0.003)
-    elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03)
+    elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)

     scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
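
Usage note (not part of the patch): a minimal sketch of how the new options are meant to be used, mirroring the `clipping_scale=2.0` call added to `_test_scaled_adam` above. The import path and toy model below are assumptions for illustration; clipping only takes effect after the first `clipping_update_period` steps, once a median gradient norm has been estimated.

    import torch
    from optim import ScaledAdam  # optim.py from this directory; import path is an assumption

    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
    )
    # clipping_scale=2.0: clip normalized gradient norms to 2x the median norm
    # observed over the most recent clipping_update_period (=100) minibatches.
    optimizer = ScaledAdam(
        model.parameters(), lr=0.03, clipping_scale=2.0, clipping_update_period=100
    )

    for step in range(300):
        x = torch.randn(32, 64)
        loss = (model(x) ** 2).mean()  # toy objective, for illustration only
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # collects gradient-norm statistics every step; logs
                          # quartiles / percent-clipped once per update period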