diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
index 57bda02ff..8a1789d45 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
@@ -720,7 +720,7 @@ class Cain(Optimizer):
 
     We will be returning:
 
-        denom = alpha * (exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
+        denom = alpha * (exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
 
     and are aiming for alpha to satisfy the following condition:
 
@@ -735,7 +735,7 @@ class Cain(Optimizer):
     The way we update alpha is as follows.  We always have the previous iteration's
     alpha, call this alpha_prev.  We compute:
 
-       denom_prev = (alpha_prev * exp_avg_sq) + beta * (exp_avg_sq.sqrt() + eps)
+       denom_prev = (alpha_prev * exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
     and:
        mean_prev = (exp_avg_sq.sqrt() + eps) / denom).mean()
     .. now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
@@ -751,9 +751,9 @@ class Cain(Optimizer):
         for _ in range(num_steps):
             # our update for alpha is iterative since there isn't a closed-form solution.
             alpha = state["alpha"]
-            adam_denom = exp_avg_sq.sqrt().add_(eps)
-            denom = alpha * exp_avg_sq + beta * adam_denom
-            mean_speed = (adam_denom / denom).mean()
+            adam_denom = exp_avg_sq.sqrt()
+            denom = alpha * exp_avg_sq + beta * adam_denom + eps
+            mean_speed = ((adam_denom + eps) / denom).mean()
             alpha.fill_(alpha * mean_speed)
         return denom
 