Restore missing factor 1-beta1

This commit is contained in:
Daniel Povey 2022-05-20 17:43:48 +08:00
parent 768c260a4d
commit 6f974b32f6

View File

@ -603,7 +603,7 @@ class Cain(Optimizer):
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"]) scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
scale_alpha = -lr*(scale_bias_correction2 ** 0.5) scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * (1-beta1)
# scale_delta is going to be the change in log-scale (before momentum), i.e. if # scale_delta is going to be the change in log-scale (before momentum), i.e. if
# p = underlying_param * scale.exp(), # p = underlying_param * scale.exp(),
@ -617,7 +617,8 @@ class Cain(Optimizer):
# specified limit, start to decay the parameters (and ignore the computed # specified limit, start to decay the parameters (and ignore the computed
# delta). # delta).
scale_delta = torch.where(state["param_rms"] > rms_max, scale_delta = torch.where(state["param_rms"] > rms_max,
torch.tensor(-1.0e-03, device=scale_delta.device, torch.tensor(-1.0e-03 * (1-beta1),
device=scale_delta.device,
dtype=scale_delta.dtype), dtype=scale_delta.dtype),
scale_delta) scale_delta)