mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Bug fixes to avoid inf alpha
This commit is contained in:
parent
0787583580
commit
7aa47408af
@ -720,7 +720,7 @@ class Cain(Optimizer):
|
|||||||
|
|
||||||
We will be returning:
|
We will be returning:
|
||||||
|
|
||||||
denom = alpha * (exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
|
denom = alpha * (exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
|
||||||
|
|
||||||
and are aiming for alpha to satisfy the following condition:
|
and are aiming for alpha to satisfy the following condition:
|
||||||
|
|
||||||
@ -735,7 +735,7 @@ class Cain(Optimizer):
|
|||||||
|
|
||||||
The way we update alpha is as follows. We always have the previous iteration's
|
The way we update alpha is as follows. We always have the previous iteration's
|
||||||
alpha, call this alpha_prev. We compute:
|
alpha, call this alpha_prev. We compute:
|
||||||
denom_prev = (alpha_prev * exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
|
denom_prev = (alpha_prev * exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
|
||||||
and:
|
and:
|
||||||
mean_prev = ((exp_avg_sq.sqrt() + eps) / denom_prev).mean()
|
mean_prev = ((exp_avg_sq.sqrt() + eps) / denom_prev).mean()
|
||||||
.. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
|
.. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
|
||||||
@ -751,9 +751,9 @@ class Cain(Optimizer):
|
|||||||
for _ in range(num_steps):
|
for _ in range(num_steps):
|
||||||
# our update for alpha is iterative since there isn't a closed-form solution.
|
# our update for alpha is iterative since there isn't a closed-form solution.
|
||||||
alpha = state["alpha"]
|
alpha = state["alpha"]
|
||||||
adam_denom = exp_avg_sq.sqrt().add_(eps)
|
adam_denom = exp_avg_sq.sqrt()
|
||||||
denom = alpha * exp_avg_sq + beta * adam_denom
|
denom = alpha * exp_avg_sq + beta * adam_denom + eps
|
||||||
mean_speed = (adam_denom / denom).mean()
|
mean_speed = ((adam_denom + eps) / denom).mean()
|
||||||
alpha.fill_(alpha * mean_speed)
|
alpha.fill_(alpha * mean_speed)
|
||||||
return denom
|
return denom
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user