mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Slightly simplify scaling code
This commit is contained in:
parent
4257c1024f
commit
768c260a4d
@ -582,12 +582,13 @@ class Cain(Optimizer):
|
||||
# that also emulates Eve.
|
||||
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
|
||||
dtype=p.dtype)
|
||||
state["scale_exp_avg"] = torch.zeros((), device=p.device,
|
||||
dtype=p.dtype)
|
||||
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
step = state["step"]
|
||||
|
||||
exp_avg.mul_(beta1)
|
||||
|
||||
if p.numel() > 1:
|
||||
# This part learns the scale on the parameters.
|
||||
# scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
|
||||
@ -595,21 +596,19 @@ class Cain(Optimizer):
|
||||
# w.r.t. scale (it does not actually matter what value `scale` currently has).
|
||||
scale_deriv = (p * grad).sum()
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"]
|
||||
scale_exp_avg = state["scale_exp_avg"]
|
||||
scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
|
||||
value=1 - beta2)
|
||||
|
||||
scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
|
||||
scale_bias_correction2 = 1 - beta2 ** step
|
||||
|
||||
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
|
||||
|
||||
scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
|
||||
|
||||
# scale_delta is going to be the change in log-scale, i.e. if
|
||||
# scale_delta is going to be the change in log-scale (before momentum), i.e. if
|
||||
# p = underlying_param * scale.exp(),
|
||||
# delta is the change in `scale`.
|
||||
scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
|
||||
scale_delta = scale_alpha * (scale_deriv / scale_denom)
|
||||
|
||||
# the following is equivalent to:
|
||||
# if torch.all(state["param_rms"] > rms_max):
|
||||
@ -622,9 +621,8 @@ class Cain(Optimizer):
|
||||
dtype=scale_delta.dtype),
|
||||
scale_delta)
|
||||
|
||||
p.mul_(1 + scale_delta)
|
||||
exp_avg.add_(p, alpha=scale_delta)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
|
||||
# forget bias_correction1. We use normal momentum. exp_avg really
|
||||
# just stores the moving-average gradient step.
|
||||
@ -640,7 +638,6 @@ class Cain(Optimizer):
|
||||
this_delta = grad / denom
|
||||
this_delta = self._change_coordinates(this_delta, state, forward=False)
|
||||
|
||||
|
||||
alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
|
||||
if p.numel() > 1:
|
||||
if step % 10 == 0:
|
||||
@ -651,8 +648,7 @@ class Cain(Optimizer):
|
||||
|
||||
# treat/repurpose exp_avg as moving-average of `this_delta`
|
||||
# don't use alpha=alpha as I don't think
|
||||
exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
|
||||
|
||||
exp_avg.add_(this_delta, alpha=alpha)
|
||||
|
||||
p.add_(exp_avg)
|
||||
state["step"] += 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user