Slightly simplify scaling code

Daniel Povey 2022-05-20 16:43:05 +08:00
parent 4257c1024f
commit 768c260a4d


@@ -582,12 +582,13 @@ class Cain(Optimizer):
                 # that also emulates Eve.
                 state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
                                                         dtype=p.dtype)
-                state["scale_exp_avg"] = torch.zeros((), device=p.device,
-                                                     dtype=p.dtype)

+            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
             step = state["step"]
+            exp_avg.mul_(beta1)

             if p.numel() > 1:
                 # This part learns the scale on the parameters.
                 # scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
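The comment that opens here (and continues in the next hunk) defines `scale_deriv` as the loss derivative w.r.t. a log-space scale on the parameters: if p = underlying_param * scale.exp(), then d(loss)/d(scale) = (grad * p).sum() at the current parameter value, whatever `scale` happens to be. A minimal numeric check of that identity, evaluated at scale = 0 (not part of the commit; the cubic loss is an arbitrary stand-in):

```python
# Check that (p * grad).sum() equals d(loss)/d(s) for p = s.exp() * u,
# evaluated at s = 0 (i.e. u is the current parameter value).
import math
import torch

torch.manual_seed(0)
p = torch.randn(10, dtype=torch.float64, requires_grad=True)
loss = (p ** 3).sum()               # arbitrary differentiable toy loss
loss.backward()
scale_deriv = (p * p.grad).sum()

# Central finite difference of s -> loss(p * exp(s)) at s = 0:
eps = 1e-4
f = lambda s: ((p.detach() * math.exp(s)) ** 3).sum()
fd = (f(eps) - f(-eps)) / (2 * eps)
print(scale_deriv.item(), fd.item())  # the two values should agree closely
```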
@@ -595,21 +596,19 @@ class Cain(Optimizer):
                 # w.r.t. scale (it does not actually matter what value `scale` currently has).
                 scale_deriv = (p * grad).sum()
                 scale_exp_avg_sq = state["scale_exp_avg_sq"]
-                scale_exp_avg = state["scale_exp_avg"]
                 scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                       value=1 - beta2)
-                scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))

                 scale_bias_correction2 = 1 - beta2 ** step
                 scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
                 scale_alpha = -lr*(scale_bias_correction2 ** 0.5)

-                # scale_delta is going to be the change in log-scale, i.e. if
+                # scale_delta is going to be the change in log-scale (before momentum), i.e. if
                 # p = underlying_param * scale.exp(),
                 # delta is the change in `scale`.
-                scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
+                scale_delta = scale_alpha * (scale_deriv / scale_denom)

                 # the following is equivalent to:
                 # if torch.all(state["param_rms"] > rms_max):
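With `scale_exp_avg` gone, what survives of this hunk is an Adam-style normalization of the raw scalar derivative; first-moment smoothing now happens downstream through `exp_avg`. A self-contained sketch of just this scalar step (the hyperparameter values are illustrative, not the optimizer's defaults):

```python
import torch

lr, beta2, eps = 0.01, 0.98, 1e-8   # illustrative values only
scale_exp_avg_sq = torch.zeros(())  # running second moment of scale_deriv

def scale_delta(scale_deriv: torch.Tensor, step: int) -> torch.Tensor:
    """Change in log-scale, before momentum (mirrors the hunk above)."""
    scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                          value=1 - beta2)
    bias_correction2 = 1 - beta2 ** step
    denom = scale_exp_avg_sq.sqrt() + eps
    return -lr * (bias_correction2 ** 0.5) * (scale_deriv / denom)
```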
@@ -622,9 +621,8 @@ class Cain(Optimizer):
                                                dtype=scale_delta.dtype),
                                    scale_delta)
-                p.mul_(1 + scale_delta)
+                exp_avg.add_(p, alpha=scale_delta)

-            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

             # forget bias_correction1. We use normal momentum. exp_avg really
             # just stores the moving-average gradient step.
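The substantive change in this hunk: rather than rescaling `p` in place, the scale change is deposited into the momentum buffer (`exp_avg.add_(p, alpha=scale_delta)`) and reaches the parameter through the final `p.add_(exp_avg)`, smoothed by beta1 like every other update. A toy single-step check, started from an empty buffer so the old and new forms coincide exactly:

```python
import torch

beta1, scale_delta = 0.9, -0.01
p = torch.randn(5, dtype=torch.float64)
exp_avg = torch.zeros(5, dtype=torch.float64)   # no prior momentum, for clarity

p_old = p.clone()
exp_avg.mul_(beta1).add_(p, alpha=scale_delta)  # new: scale change -> momentum buffer
p.add_(exp_avg)                                 # final parameter update

# Identical to the old in-place p.mul_(1 + scale_delta) when exp_avg was empty:
print(torch.allclose(p, p_old * (1 + scale_delta)))  # True
```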
@@ -640,7 +638,6 @@ class Cain(Optimizer):
             this_delta = grad / denom
             this_delta = self._change_coordinates(this_delta, state, forward=False)
             alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)

             if p.numel() > 1:
                 if step % 10 == 0:
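For context on the unchanged lines above: `denom` and `bias_correction2` are defined outside the visible hunks, but the shape of the update suggests the usual Adam second-moment denominator, with the bias correction folded into `alpha` rather than into the buffers, the same pattern the scale branch uses. A sketch under that assumption (`_change_coordinates` is omitted; all values are stand-ins):

```python
import torch

lr, beta1, beta2, eps, step = 0.01, 0.9, 0.98, 1e-8, 10  # stand-in values
grad = torch.randn(5)
exp_avg_sq = torch.rand(5) + 1e-4   # stand-in for the tracked second moment

bias_correction2 = 1 - beta2 ** step
denom = exp_avg_sq.sqrt() + eps     # uncorrected; sqrt(bias_correction2) lives in alpha
this_delta = grad / denom
alpha = -lr * (1 - beta1) * (bias_correction2 ** 0.5)
# alpha * this_delta is what the next hunk folds into exp_avg.
```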
@@ -651,8 +648,7 @@ class Cain(Optimizer):
             # treat/repurpose exp_avg as moving-average of `this_delta`
-            # don't use alpha=alpha as I don't think
-            exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
+            exp_avg.add_(this_delta, alpha=alpha)

             p.add_(exp_avg)

             state["step"] += 1