From 768c260a4d391b6ba2b8f673ff460426e6cdd851 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 20 May 2022 16:43:05 +0800
Subject: [PATCH] Slightly simplify scaling code

---
 .../ASR/pruned_transducer_stateless4/optim.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
index 921df6c63..2fe84cf21 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
@@ -582,12 +582,13 @@ class Cain(Optimizer):
                     # that also emulates Eve.
                     state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
                                                             dtype=p.dtype)
-                    state["scale_exp_avg"] = torch.zeros((), device=p.device,
-                                                         dtype=p.dtype)


+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 step = state["step"]

+                exp_avg.mul_(beta1)
+
                 if p.numel() > 1:
                     # This part learns the scale on the parameters.
                     # scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
@@ -595,21 +596,19 @@ class Cain(Optimizer):
                     # w.r.t. scale (it does not actually matter what value `scale` currently has).
                     scale_deriv = (p * grad).sum()
                     scale_exp_avg_sq = state["scale_exp_avg_sq"]
-                    scale_exp_avg = state["scale_exp_avg"]
                     scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                           value=1 - beta2)
-                    scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))

                     scale_bias_correction2 = 1 - beta2 ** step

                     scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])

                     scale_alpha = -lr*(scale_bias_correction2 ** 0.5)

-                    # scale_delta is going to be the change in log-scale, i.e. if
+                    # scale_delta is going to be the change in log-scale (before momentum), i.e. if
                     #    p = underlying_param * scale.exp(),
                     # delta is the change in `scale`.
-                    scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
+                    scale_delta = scale_alpha * (scale_deriv / scale_denom)

                     # the following is equivalent to:
                     # if torch.all(state["param_rms"] > rms_max):
@@ -622,9 +621,8 @@ class Cain(Optimizer):
                                                           dtype=scale_delta.dtype),
                                               scale_delta)

-                    p.mul_(1 + scale_delta)
+                    exp_avg.add_(p, alpha=scale_delta)

-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                 # forget bias_correction1.  We use normal momentum.  exp_avg really
                 # just stores the moving-average gradient step.
@@ -640,7 +638,6 @@ class Cain(Optimizer):
                 this_delta = grad / denom

                 this_delta = self._change_coordinates(this_delta, state, forward=False)
-
                 alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                 if p.numel() > 1:
                     if step % 10 == 0:
@@ -651,8 +648,7 @@ class Cain(Optimizer):

                 # treat/repurpose exp_avg as moving-average of `this_delta`
                 # don't use alpha=alpha as I don't think
-                exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
-
+                exp_avg.add_(this_delta, alpha=alpha)
                 p.add_(exp_avg)
                 state["step"] += 1
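
For reference, a minimal standalone sketch of the scale update as it looks after this patch: the separate scale_exp_avg momentum buffer is gone, the raw scale_deriv is RMS-normalised directly, and the resulting log-scale change is folded into the shared exp_avg momentum buffer instead of rescaling p in place. This is an illustration under assumptions, not code from the patch: the helper name scale_step, the freestanding buffers, and the hyperparameter values are made up here; in Cain the buffers live in state, exp_avg is decayed by beta1 up front and shared with the gradient step, and p is updated once via p.add_(exp_avg).

    import torch


    def scale_step(p, grad, exp_avg, scale_exp_avg_sq,
                   lr=0.003, beta2=0.98, eps=1e-8, step=1):
        # Derivative of the loss w.r.t. a log-space scale on p: if we imagine
        # p = underlying_param * scale.exp(), this is d(loss)/d(scale).
        scale_deriv = (p * grad).sum()

        # Adam-style second moment for the scale gradient; after this patch
        # there is no separate first-moment ("scale_exp_avg") buffer.
        scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                              value=1 - beta2)
        scale_bias_correction2 = 1 - beta2 ** step
        scale_denom = scale_exp_avg_sq.sqrt().add_(eps)
        scale_alpha = -lr * (scale_bias_correction2 ** 0.5)

        # Change in log-scale, before momentum.
        scale_delta = scale_alpha * (scale_deriv / scale_denom)

        # Fold the scale change into the shared momentum buffer (the patch
        # writes this as exp_avg.add_(p, alpha=scale_delta)); p itself is
        # not touched here.
        exp_avg.add_(p * scale_delta)


    p = torch.randn(4, 3)
    grad = torch.randn(4, 3)
    exp_avg = torch.zeros_like(p)   # in Cain: state["exp_avg"], already decayed by beta1
    scale_exp_avg_sq = torch.zeros(())

    scale_step(p, grad, exp_avg, scale_exp_avg_sq)
    p.add_(exp_avg)                 # in Cain the gradient step is also added to exp_avg first

The net effect of the patch is that the log-scale change and the gradient step now share one momentum buffer (exp_avg), so the separate scale_exp_avg state and the second beta1 decay are no longer needed.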