diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
index 0289db2f1..9aaf356cf 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
@@ -334,6 +334,10 @@ class Abel(Optimizer):
                     # to standard Adam.
                     denom = (exp_avg_sq/bias_correction2 + eps*eps).sqrt()
                     this_delta = grad / denom
+
+                    # eventually will remove this special case, this is to deal
+                    # with our model that has scaling factors.
+                    lr_eff = lr * (bias_correction2 ** 0.5)
                 else:
                     # update that includes factorization-related stuff..
                     factors_exp_avg_sq, factorization = state["factors_exp_avg_sq"], state["factorization"]
@@ -396,8 +400,8 @@ class Abel(Optimizer):
                     # and x_factor2: again, before taking into account the learning
                     # rate or momentum.
                     this_delta = ((grad * factorization / denom) + p * factors_sum)
+                    lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
 
-                lr_eff = lr * (bias_correction2 ** 0.5) / target_rms
 
                 delta.mul_(beta1).add_(this_delta, alpha=-lr_eff*(1-beta1))
                 p.add_(delta)
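
For context, the patch moves the effective-learning-rate computation inside the two branches: the standard-Adam-style path (the special case for parameters that are scaling factors) now uses lr_eff = lr * (bias_correction2 ** 0.5) without the / target_rms division, while the factorized path keeps the division. A minimal runnable sketch of this update pattern follows; it is not the full Abel optimizer, and the function name abel_style_update, the is_scale flag, and the flat state dict are illustrative stand-ins, not names from the patch:

import torch

def abel_style_update(p, grad, state, lr=0.01, betas=(0.9, 0.98),
                      eps=1e-8, target_rms=0.1, is_scale=False):
    """Single-parameter update mirroring the branch structure in the
    diff above.  `is_scale` stands in for the patch's special case for
    scaling-factor parameters."""
    beta1, beta2 = betas
    state["step"] += 1
    bias_correction2 = 1 - beta2 ** state["step"]

    # Adam-style second-moment accumulator.
    exp_avg_sq = state["exp_avg_sq"]
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    denom = (exp_avg_sq / bias_correction2 + eps * eps).sqrt()
    this_delta = grad / denom

    if is_scale:
        # Special case from the patch: scaling factors skip the
        # target_rms normalization.
        lr_eff = lr * (bias_correction2 ** 0.5)
    else:
        lr_eff = lr * (bias_correction2 ** 0.5) / target_rms

    # Momentum on the update itself, as in the lines following both hunks.
    delta = state["delta"]
    delta.mul_(beta1).add_(this_delta, alpha=-lr_eff * (1 - beta1))
    p.add_(delta)

Example usage:

p = torch.zeros(4)
state = {"step": 0, "exp_avg_sq": torch.zeros(4), "delta": torch.zeros(4)}
abel_style_update(p, torch.randn(4), state)

The point of branching on lr_eff is that dividing by target_rms rescales updates for parameters whose RMS is driven toward target_rms; scaling factors are not normalized that way, so they get the plain bias-corrected rate.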