From 679972b9050bdf630842c8cc4906833bbb2948df Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 3 Jun 2022 19:37:48 +0800 Subject: [PATCH] Fix bug; make epsilon work both ways (small+large); increase epsilon to 0.1 --- .../ASR/pruned_transducer_stateless7/optim.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 838200deb..8a51edf32 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -118,7 +118,7 @@ class Cain(Optimizer): params, lr=1e-3, ng_pow=-0.5, - ng_eps=0.05, + ng_eps=0.1, betas=(0.9, 0.98), eps=1e-8, target_rms=0.1, @@ -281,10 +281,10 @@ class Cain(Optimizer): grad_rms = (exp_avg_sq.sqrt()).add_(eps).mul_(1.0 / bias_correction2) this_delta = grad / grad_rms - - if p.ndim() > 1: - smoothed_grad_rms = ng_eps * grad_rms.mean() + grad_rms - lr_factor = (smoothed_grad_rms ** ng_pow) + if p.ndim > 1: + grad_rms_mean = grad_rms.mean() + smoothed_grad_rms = ng_eps * grad_rms_mean + grad_rms + lr_factor = (smoothed_grad_rms ** ng_pow) + ng_eps * (grad_rms_mean ** ng_pow) lr_factor_mean = lr_factor.mean() # this `lr` contains all the terms in learning rate except param_rms. Imagine # that param_rms is 1.0 for purposes of computing the shrinkage coefficient; actually