Fix bias_correction2, power of 0.5

Daniel Povey 2022-06-17 20:20:16 +08:00
parent 2d13b682fe
commit cbccc1dd91


@@ -270,7 +270,7 @@ class NeutralGradient(Optimizer):
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 grad_rms = (exp_avg_sq.sqrt()).add_(grad_eps)
                 this_delta = grad / grad_rms
-                alpha = (-(1-beta1) * lr * bias_correction2) * scale
+                alpha = (-(1-beta1) * lr * (bias_correction2 ** 0.5)) * scale
                 delta.add_(this_delta, alpha=alpha)
             else:
                 # The full update.
@@ -296,7 +296,7 @@ class NeutralGradient(Optimizer):
                 cur_grad = cur_grad / (exp_avg_sq.sqrt() + grad_eps)
                 if bias_correction2 < 0.99:
-                    cur_grad *= bias_correction2
+                    cur_grad *= (bias_correction2 ** 0.5)
                 cur_grad = self._change_coordinates(cur_grad, state, forward=False)
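
Why the power of 0.5: exp_avg_sq tracks the second moment E[grad^2], but the update divides by its square root, so the bias correction must also enter under a square root. A minimal sketch of that equivalence, assuming bias_correction2 = 1 - beta2**step as in torch.optim.Adam (its definition is not shown in this diff):

    # Sketch only: bias_correction2 = 1 - beta2**step is an assumption borrowed
    # from torch.optim.Adam; the diff above does not show how it is defined here.
    import torch

    beta2, step = 0.98, 5
    bias_correction2 = 1 - beta2 ** step

    exp_avg_sq = torch.zeros(10)
    for _ in range(step):
        grad = torch.randn(10)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # exp_avg_sq underestimates E[grad^2] by the factor bias_correction2,
    # so the unbiased RMS is sqrt(exp_avg_sq / bias_correction2).
    unbiased = grad / (exp_avg_sq / bias_correction2).sqrt()

    # Folding the correction into the step instead: divide by the raw RMS and
    # multiply by bias_correction2 ** 0.5 -- the factor this commit restores.
    folded = (grad / exp_avg_sq.sqrt()) * (bias_correction2 ** 0.5)

    print(torch.allclose(unbiased, folded))  # True

Multiplying by bias_correction2 itself (the old code) would over-shrink the step early in training, when bias_correction2 is well below 1.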