mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix bias_correction2, power of 0.5
This commit is contained in:
parent
2d13b682fe
commit
cbccc1dd91
@ -270,7 +270,7 @@ class NeutralGradient(Optimizer):
|
|||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||||
grad_rms = (exp_avg_sq.sqrt()).add_(grad_eps)
|
grad_rms = (exp_avg_sq.sqrt()).add_(grad_eps)
|
||||||
this_delta = grad / grad_rms
|
this_delta = grad / grad_rms
|
||||||
alpha = (-(1-beta1) * lr * bias_correction2) * scale
|
alpha = (-(1-beta1) * lr * (bias_correction2 ** 0.5)) * scale
|
||||||
delta.add_(this_delta, alpha=alpha)
|
delta.add_(this_delta, alpha=alpha)
|
||||||
else:
|
else:
|
||||||
# The full update.
|
# The full update.
|
||||||
@ -296,7 +296,7 @@ class NeutralGradient(Optimizer):
|
|||||||
|
|
||||||
cur_grad = cur_grad / (exp_avg_sq.sqrt() + grad_eps)
|
cur_grad = cur_grad / (exp_avg_sq.sqrt() + grad_eps)
|
||||||
if bias_correction2 < 0.99:
|
if bias_correction2 < 0.99:
|
||||||
cur_grad *= bias_correction2
|
cur_grad *= (bias_correction2 ** 0.5)
|
||||||
|
|
||||||
cur_grad = self._change_coordinates(cur_grad, state, forward=False)
|
cur_grad = self._change_coordinates(cur_grad, state, forward=False)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user