diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 838200deb..8a51edf32 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -118,7 +118,7 @@ class Cain(Optimizer):
         params,
         lr=1e-3,
         ng_pow=-0.5,
-        ng_eps=0.05,
+        ng_eps=0.1,
         betas=(0.9, 0.98),
         eps=1e-8,
         target_rms=0.1,
@@ -281,10 +281,10 @@ class Cain(Optimizer):
                 grad_rms = (exp_avg_sq.sqrt()).add_(eps).mul_(1.0 / bias_correction2)
                 this_delta = grad / grad_rms
 
-
-                if p.ndim() > 1:
-                    smoothed_grad_rms = ng_eps * grad_rms.mean() + grad_rms
-                    lr_factor = (smoothed_grad_rms ** ng_pow)
+                if p.ndim > 1:
+                    grad_rms_mean = grad_rms.mean()
+                    smoothed_grad_rms = ng_eps * grad_rms_mean + grad_rms
+                    lr_factor = (smoothed_grad_rms ** ng_pow) + ng_eps * (grad_rms_mean ** ng_pow)
                 lr_factor_mean = lr_factor.mean()
                 # this `lr` contains all the terms in learning rate except param_rms. Imagine
                 # that param_rms is 1.0 for purposes of computing the shrinkage coefficient; actually