Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 08:34:19 +00:00)
Fix bug; make epsilon work both ways (small+large); increase epsilon to 0.1
This commit is contained in:
parent 8085ed6ef9
commit 679972b905
@@ -118,7 +118,7 @@ class Cain(Optimizer):
|
||||
params,
|
||||
lr=1e-3,
|
||||
ng_pow=-0.5,
|
||||
ng_eps=0.05,
|
||||
ng_eps=0.1,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1e-8,
|
||||
target_rms=0.1,
|
||||
@@ -281,10 +281,10 @@ class Cain(Optimizer):
|
||||
grad_rms = (exp_avg_sq.sqrt()).add_(eps).mul_(1.0 / bias_correction2)
|
||||
this_delta = grad / grad_rms
|
||||
|
||||
|
||||
if p.ndim() > 1:
|
||||
smoothed_grad_rms = ng_eps * grad_rms.mean() + grad_rms
|
||||
lr_factor = (smoothed_grad_rms ** ng_pow)
|
||||
if p.ndim > 1:
|
||||
grad_rms_mean = grad_rms.mean()
|
||||
smoothed_grad_rms = ng_eps * grad_rms_mean + grad_rms
|
||||
lr_factor = (smoothed_grad_rms ** ng_pow) + ng_eps * (grad_rms_mean ** ng_pow)
|
||||
lr_factor_mean = lr_factor.mean()
|
||||
# this `lr` contains all the terms in learning rate except param_rms. Imagine
|
||||
# that param_rms is 1.0 for purposes of computing the shrinkage coefficient; actually
|
||||
|
Loading…
x
Reference in New Issue
Block a user