mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Restore missing factor 1-beta1
This commit is contained in:
parent
768c260a4d
commit
6f974b32f6
@ -603,7 +603,7 @@ class Cain(Optimizer):
|
|||||||
|
|
||||||
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
|
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
|
||||||
|
|
||||||
scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
|
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * (1-beta1)
|
||||||
|
|
||||||
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
|
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
|
||||||
# p = underlying_param * scale.exp(),
|
# p = underlying_param * scale.exp(),
|
||||||
@ -617,7 +617,8 @@ class Cain(Optimizer):
|
|||||||
# specified limit, start to decay the parameters (and ignore the computed
|
# specified limit, start to decay the parameters (and ignore the computed
|
||||||
# delta).
|
# delta).
|
||||||
scale_delta = torch.where(state["param_rms"] > rms_max,
|
scale_delta = torch.where(state["param_rms"] > rms_max,
|
||||||
torch.tensor(-1.0e-03, device=scale_delta.device,
|
torch.tensor(-1.0e-03 * (1-beta1),
|
||||||
|
device=scale_delta.device,
|
||||||
dtype=scale_delta.dtype),
|
dtype=scale_delta.dtype),
|
||||||
scale_delta)
|
scale_delta)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user