From 6306f244303304319827ebf0a213e8d718592d35 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 15 May 2022 14:12:23 +0800
Subject: [PATCH] Some changes to algorithm; more diagnostics printing

---
 .../pruned_transducer_stateless4b/optim.py | 55 +++++++++++--------
 1 file changed, 32 insertions(+), 23 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
index 2a4845694..f01ca24e6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
@@ -214,7 +214,8 @@ class Abel(Optimizer):
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-08).
         target_rms (float, optional): target root-mean-square value of
-            x_norm in factorization (conceptually).
+            x_norm in factorization (conceptually).  Actually this now just becomes
+            a factor in the learning rate; we may remove it at some point.


    .. _Adam\: A Method for Stochastic Optimization:
@@ -245,7 +246,7 @@
             raise ValueError(
                 "Invalid beta parameter at index 1: {}".format(betas[1])
             )
-        if not 0 < target_rms <= 10.0:
+        if not 0 < target_rms <= 1.0:
             raise ValueError("Invalid target_rms value: {}".format(target_rms))
         defaults = dict(
             lr=lr,
@@ -314,7 +315,8 @@
                         torch.zeros(*shape, dtype=p.dtype, device=p.device)
                         for shape in _get_factor_shapes(list(p.shape))
                     ]
-                    state["factorization"] = _init_factorization(p, eps)
+                    # we call _init_factorization inside step().
+                    state["factorization"] = torch.zeros_like(p)

                 exp_avg_sq = state["exp_avg_sq"]
                 delta, exp_avg_sq = state["delta"], state["exp_avg_sq"]
@@ -350,15 +352,32 @@
                 factor_grads = _get_factor_grads(p, grad)

                 if step < 10 or step % 10 == 1:
-                    # do this only every 10 steps, to save time.
+                    # update the factorization only every 10 steps, to save time.
+                    if step % 1000 == 0:
+                        # Periodically refresh the factorization; this is
+                        # out of concern that after a large number of
+                        # multiplications, roundoff could cause it to drift
+                        # significantly away from being a product of
+                        # independent factors, one per dimension.
+                        factorization[:] = _init_factorization(p, eps)
                     _update_factorization(p, factorization, speed=0.1, eps=eps)

+                factors_sum = None
+                numel = p.numel()
                 for g, e in zip(factor_grads, factors_exp_avg_sq):
+                    this_numel = numel / g.numel()
                     update_exp_avg_sq(g, e)
-                    this_denom = (e/bias_correction2 + eps*eps).sqrt()
+                    # this_numel**0.5 will turn into this_numel**-0.25 when we sqrt() and
+                    # divide by denom.  This becomes a factor in the learning rate.
+                    # The technique naturally makes us learn faster by (this_numel**0.5)
+                    # in any of the parameter directions corresponding to rows or columns,
+                    # and because we don't want to be too aggressive and risk instability or
+                    # excess param noise, we reduce this to (this_numel**0.25) by dividing
+                    # the effective LR by this_numel**0.25.
+                    this_denom = (e*(this_numel**0.5)/bias_correction2 + eps*eps).sqrt()
                     assert g.shape == this_denom.shape
                     factor_delta = g / this_denom
                     factors_sum = (factor_delta if factors_sum is None
@@ -371,18 +390,7 @@
                 # from updating the factors (before considering momentum!)

-                # We multiply by target_rms*target_rms here to have the same
-                # effect as multiplying by target_rms after the sqrt, so that later
-                # when we divide by denom it's like dividing the learning rate by
-                # target_rms.  This is to simulate the effect of keeping `factorization`
-                # updated in such a way that x_norm had an average root-mean-square
-                # value of `target_rms` instead of 1.0; this would require
-                # `factorization` being multiplied of `1/target_rms` below when setting
-                # `delta`.  We do it this way to save a tensor operation.
-                # We are very lazy about when to add `eps`, just doing it whenever is
-                # most convenient, since this is really not going to be critical.
-
-                denom = (exp_avg_sq*target_rms*target_rms/bias_correction2 +
+                denom = (exp_avg_sq/bias_correction2 +
                          eps*eps).sqrt()

                 # `grad * factorization / denom` represents the change in parameters x, only considering
@@ -398,9 +406,8 @@
                 # rate or momentum.
                 this_delta = ((grad * factorization / denom) +
                               p * factors_sum)
-
-                # compute the moving-average change in parameters, and add it to p.
-                delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
+                lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
+                delta.mul_(beta1).add_(this_delta, alpha=-lr_eff*(1-beta1))

                 p.add_(delta)
@@ -749,7 +756,7 @@
     # Abel is working as we expect and is able to adjust scales of
     # different dims differently.
     input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
-    output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

     for iter in [0,1]:
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
@@ -800,8 +807,10 @@
         print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
         print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))

-        print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
-        print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
+        print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
+        print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
+
+        print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
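Note: as a quick illustration of the `this_numel**0.5` term added above (a standalone sketch, not part of the patch; the shapes and the names p, g, e are made up), the extra factor inside the sqrt() ends up scaling the per-factor update g / this_denom by roughly this_numel**-0.25, which is the effective-learning-rate reduction the comment in the hunk describes:

    # Standalone illustration with hypothetical shapes; `g` stands in for one
    # factor gradient and `e` for its second-moment (exp_avg_sq) statistic.
    import torch

    p = torch.randn(512, 256)     # pretend weight matrix
    g = torch.randn(512, 1)       # pretend row-factor gradient
    e = g ** 2                    # pretend exp_avg_sq statistic
    eps = 1e-8
    bias_correction2 = 1.0

    this_numel = p.numel() / g.numel()    # 256.0 in this example

    # Denominator with and without the this_numel**0.5 term from the patch.
    this_denom = (e * (this_numel ** 0.5) / bias_correction2 + eps * eps).sqrt()
    plain_denom = (e / bias_correction2 + eps * eps).sqrt()

    # The extra term scales the update g / denom by about this_numel**-0.25,
    # i.e. 256**-0.25 == 0.25 here.
    print(((g / this_denom) / (g / plain_denom)).mean().item())  # ~0.25
    print(this_numel ** -0.25)                                   # 0.25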