Bug fixes to avoid inf alpha

This commit is contained in:
Daniel Povey 2022-05-27 21:41:05 +08:00
parent 0787583580
commit 7aa47408af

View File

@ -720,7 +720,7 @@ class Cain(Optimizer):
We will be returning: We will be returning:
denom = alpha * (exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps) denom = alpha * (exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
and are aiming for alpha to satisfy the following condition: and are aiming for alpha to satisfy the following condition:
@ -735,7 +735,7 @@ class Cain(Optimizer):
The way we update alpha is as follows. We always have the previous iteration's The way we update alpha is as follows. We always have the previous iteration's
alpha, call this alpha_prev. We compute: alpha, call this alpha_prev. We compute:
denom_prev = (alpha_prev * exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps) denom_prev = (alpha_prev * exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
and: and:
mean_prev = (exp_avg_sq.sqrt() + eps) / denom).mean() mean_prev = (exp_avg_sq.sqrt() + eps) / denom).mean()
.. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity .. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
@ -751,9 +751,9 @@ class Cain(Optimizer):
for _ in range(num_steps): for _ in range(num_steps):
# our update for alpha is iterative since there isn't a closed-form solution. # our update for alpha is iterative since there isn't a closed-form solution.
alpha = state["alpha"] alpha = state["alpha"]
adam_denom = exp_avg_sq.sqrt().add_(eps) adam_denom = exp_avg_sq.sqrt()
denom = alpha * exp_avg_sq + beta * adam_denom denom = alpha * exp_avg_sq + beta * adam_denom + eps
mean_speed = (adam_denom / denom).mean() mean_speed = ((adam_denom + eps) / denom).mean()
alpha.fill_(alpha * mean_speed) alpha.fill_(alpha * mean_speed)
return denom return denom