mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Bug fixes to avoid inf alpha
This commit is contained in:
parent
0787583580
commit
7aa47408af
@ -720,7 +720,7 @@ class Cain(Optimizer):
|
||||
|
||||
We will be returning:
|
||||
|
||||
denom = alpha * (exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
|
||||
denom = alpha * (exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
|
||||
|
||||
and are aiming for alpha to satisfy the following condition:
|
||||
|
||||
@ -735,7 +735,7 @@ class Cain(Optimizer):
|
||||
|
||||
The way we update alpha is as follows. We always have the previous iteration's
|
||||
alpha, call this alpha_prev. We compute:
|
||||
denom_prev = (alpha_prev * exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
|
||||
denom_prev = (alpha_prev * exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
|
||||
and:
|
||||
mean_prev = (exp_avg_sq.sqrt() + eps) / denom).mean()
|
||||
.. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
|
||||
@ -751,9 +751,9 @@ class Cain(Optimizer):
|
||||
for _ in range(num_steps):
|
||||
# our update for alpha is iterative since there isn't a closed-form solution.
|
||||
alpha = state["alpha"]
|
||||
adam_denom = exp_avg_sq.sqrt().add_(eps)
|
||||
denom = alpha * exp_avg_sq + beta * adam_denom
|
||||
mean_speed = (adam_denom / denom).mean()
|
||||
adam_denom = exp_avg_sq.sqrt()
|
||||
denom = alpha * exp_avg_sq + beta * adam_denom + eps
|
||||
mean_speed = ((adam_denom + eps) / denom).mean()
|
||||
alpha.fill_(alpha * mean_speed)
|
||||
return denom
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user