mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Restore missing factor (1-beta1)
This commit is contained in:
parent
768c260a4d
commit
6f974b32f6
@@ -603,7 +603,7 @@ class Cain(Optimizer):
|
||||
|
||||
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
|
||||
|
||||
scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
|
||||
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * (1-beta1)
|
||||
|
||||
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
|
||||
# p = underlying_param * scale.exp(),
|
||||
@@ -617,7 +617,8 @@ class Cain(Optimizer):
|
||||
# specified limit, start to decay the parameters (and ignore the computed
|
||||
# delta).
|
||||
scale_delta = torch.where(state["param_rms"] > rms_max,
|
||||
torch.tensor(-1.0e-03, device=scale_delta.device,
|
||||
torch.tensor(-1.0e-03 * (1-beta1),
|
||||
device=scale_delta.device,
|
||||
dtype=scale_delta.dtype),
|
||||
scale_delta)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user