Bug fixes to avoid inf alpha

Daniel Povey 2022-05-27 21:41:05 +08:00
parent 0787583580
commit 7aa47408af

@@ -720,7 +720,7 @@ class Cain(Optimizer):
We will be returning:
- denom = alpha * (exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
+ denom = alpha * (exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
and are aiming for alpha to satisfy the following condition:
@@ -735,7 +735,7 @@ class Cain(Optimizer):
The way we update alpha is as follows. We always have the previous iteration's
alpha, call this alpha_prev. We compute:
- denom_prev = (alpha_prev * exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
+ denom_prev = (alpha_prev * exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
and:
mean_prev = ((exp_avg_sq.sqrt() + eps) / denom_prev).mean()
.. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
@@ -751,9 +751,9 @@ class Cain(Optimizer):
for _ in range(num_steps):
# our update for alpha is iterative since there isn't a closed-form solution.
alpha = state["alpha"]
- adam_denom = exp_avg_sq.sqrt().add_(eps)
- denom = alpha * exp_avg_sq + beta * adam_denom
- mean_speed = (adam_denom / denom).mean()
+ adam_denom = exp_avg_sq.sqrt()
+ denom = alpha * exp_avg_sq + beta * adam_denom + eps
+ mean_speed = ((adam_denom + eps) / denom).mean()
alpha.fill_(alpha * mean_speed)
return denom
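
One plausible reading of the commit title, sketched below in plain Python rather than torch: when exp_avg_sq is all zeros, the old formula gives mean_speed = eps / (beta * eps) = 1 / beta, so for beta < 1 alpha is multiplied by 1 / beta on every iteration and eventually overflows to inf. With eps moved outside the beta term, mean_speed is eps / eps = 1 and alpha is left alone. The function names, the list-based representation, and the values beta=0.1, eps=1e-8 and num_iters=400 are assumptions made for illustration; they do not come from the commit.

```python
import math


def step_old(alpha, exp_avg_sq, beta, eps):
    """One alpha update using the pre-commit formula (eps inside the beta term)."""
    adam_denom = [math.sqrt(v) + eps for v in exp_avg_sq]
    denom = [alpha * v + beta * a for v, a in zip(exp_avg_sq, adam_denom)]
    mean_speed = sum(a / d for a, d in zip(adam_denom, denom)) / len(denom)
    return alpha * mean_speed


def step_new(alpha, exp_avg_sq, beta, eps):
    """One alpha update using the post-commit formula (eps added once, outside)."""
    adam_denom = [math.sqrt(v) for v in exp_avg_sq]
    denom = [alpha * v + beta * a + eps for v, a in zip(exp_avg_sq, adam_denom)]
    mean_speed = sum((a + eps) / d for a, d in zip(adam_denom, denom)) / len(denom)
    return alpha * mean_speed


def run(step, exp_avg_sq, beta=0.1, eps=1e-8, num_iters=400, alpha=1.0):
    """Apply `step` repeatedly; stop early once alpha is no longer finite.

    beta, eps and num_iters are hypothetical values chosen for this sketch.
    """
    for _ in range(num_iters):
        alpha = step(alpha, exp_avg_sq, beta, eps)
        if not math.isfinite(alpha):
            break
    return alpha


zeros = [0.0, 0.0, 0.0]
alpha_old = run(step_old, zeros)  # grows by roughly 1 / beta per iteration
alpha_new = run(step_new, zeros)  # mean_speed is exactly 1, alpha unchanged
print(alpha_old, alpha_new)
```

The early exit in run() is there only so the sketch reports the first non-finite alpha; it is not part of the optimizer shown in the diff.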