Bug fix for scalar case

This commit is contained in:
Daniel Povey 2022-06-17 12:16:12 +08:00
parent 9e92d13a33
commit 47df144253

View File

@ -329,7 +329,7 @@ class NeutralGradient(Optimizer):
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** (step + 1)
denom = (exp_avg_sq.sqrt()).add_(eps)
denom = (exp_avg_sq.sqrt()).add_(grad_eps)
this_delta = grad / denom
alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
delta.add_(this_delta, alpha=alpha)