From e2fdfe990ccc47fdbe695d6b844c09995779d5fc Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 20 Sep 2022 15:20:43 +0800
Subject: [PATCH] Loosen limit on param_max_rms, from 2.0 to 3.0; change how param_min_rms is applied.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 257312f9a..77cf6d9db 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -78,7 +78,7 @@ class ScaledAdam(Optimizer):
         scalar_lr_scale=0.1,
         eps=1.0e-08,
         param_min_rms=1.0e-05,
-        param_max_rms=2.0,
+        param_max_rms=3.0,
         scalar_max=2.0,
         size_update_period=4,
         clipping_update_period=100,
@@ -272,7 +272,6 @@ class ScaledAdam(Optimizer):
         """
         lr = group["lr"]
         size_update_period = group["size_update_period"]
-        eps = group["eps"]
         beta1 = group["betas"][0]
 
         grad = p.grad
@@ -289,7 +288,7 @@ class ScaledAdam(Optimizer):
             scale_grads[step % size_update_period] = (p * grad).sum()
             if step % size_update_period == size_update_period - 1:
                 param_rms = state["param_rms"]
-                param_rms.copy_((p**2).mean().sqrt().clamp_(min=eps))
+                param_rms.copy_((p**2).mean().sqrt())
                 if step > 0:
                     # self._size_update() learns the overall scale on the
                     # parameter, by shrinking or expanding it.
@@ -330,6 +329,7 @@ class ScaledAdam(Optimizer):
         size_lr = group["lr"] * group["scalar_lr_scale"]
         param_min_rms = group["param_min_rms"]
         param_max_rms = group["param_max_rms"]
+        eps = group["eps"]
         step = state["step"]
         size_update_period = scale_grads.shape[0]
 
@@ -348,7 +348,6 @@ class ScaledAdam(Optimizer):
 
         # we don't bother with bias_correction1; this will help prevent divergence
         # at the start of training.
-        eps = 1.0e-10 # this value is not critical.
         denom = scale_exp_avg_sq.sqrt() + eps
 
         scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum() / denom
@@ -356,7 +355,9 @@
 
         is_too_small = (param_rms < param_min_rms)
         is_too_large = (param_rms > param_max_rms)
-        scale_step.masked_fill_(is_too_small, size_lr * size_update_period)
+        # when the param gets too small, just don't shrink it any further.
+        scale_step.masked_fill_(is_too_small, 0.0)
+        # when it gets too large, stop it from getting any larger.
         scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
@@ -383,8 +384,10 @@
         lr = group["lr"]
         beta1, beta2 = group["betas"]
         eps = group["eps"]
+        param_min_rms = group["param_min_rms"]
 
         step = state["step"]
+
         exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1-beta2))
 
@@ -399,8 +402,7 @@
 
         denom += eps
         grad = grad / denom
-
-        alpha = -lr * (1-beta1) * state["param_rms"]
+        alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms)
         delta = state["delta"]
         delta.add_(grad * alpha)
 
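
Note for reviewers: the snippet below is an illustrative sketch, not part of the patch. It walks through the two behavioural changes on a single tensor, with made-up values for the group settings (lr, beta1, size_lr, size_update_period) and a random stand-in for the normalized gradient; in ScaledAdam these come from the param group and optimizer state.

# Illustrative sketch only -- not part of the patch.
import torch

param_min_rms, param_max_rms = 1.0e-05, 3.0   # defaults after this patch
lr, beta1 = 3.0e-02, 0.9                      # hypothetical group settings
size_lr, size_update_period = lr * 0.1, 4     # hypothetical lr * scalar_lr_scale

p = torch.randn(20, 20) * 1.0e-07             # parameter whose rms is below param_min_rms
grad = torch.randn_like(p)                    # stand-in for the normalized gradient

# param_rms is no longer clamped when it is computed.
param_rms = (p ** 2).mean().sqrt()

# Size (overall-scale) update: pretend the learned scale step wants to
# shrink this already-tiny parameter further.
scale_step = torch.tensor(-0.01)
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
# Old rule: masked_fill_(is_too_small, size_lr * size_update_period), i.e. force growth.
# New rule: just stop shrinking.
scale_step = scale_step.masked_fill(is_too_small, 0.0)
scale_step = scale_step.masked_fill(is_too_large, -size_lr * size_update_period)
print("scale_step:", scale_step.item())       # 0.0 for this tiny parameter

# Element-wise update: param_min_rms now acts as a floor on the
# learning-rate scale, so the per-element step does not vanish.
alpha = -lr * (1 - beta1) * param_rms.clamp(min=param_min_rms)
delta = grad * alpha
print("per-element step rms:", (delta ** 2).mean().sqrt().item())

In short, a parameter whose RMS has fallen below param_min_rms is no longer pushed back up by the scale update (it is simply not shrunk further), and its element-wise step keeps a floor of lr * (1 - beta1) * param_min_rms instead of going to zero.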