Mirror of https://github.com/k2-fsa/icefall.git
Slightly simplify scaling code

commit 768c260a4d
parent 4257c1024f
@@ -582,12 +582,13 @@ class Cain(Optimizer):
                     # that also emulates Eve.
                     state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
                                                             dtype=p.dtype)
-                    state["scale_exp_avg"] = torch.zeros((), device=p.device,
-                                                         dtype=p.dtype)

+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 step = state["step"]

+                exp_avg.mul_(beta1)

                 if p.numel() > 1:
                     # This part learns the scale on the parameters.
                     # scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
@@ -595,21 +596,19 @@ class Cain(Optimizer):
                     # w.r.t. scale (it does not actually matter what value `scale` currently has).
                     scale_deriv = (p * grad).sum()
                     scale_exp_avg_sq = state["scale_exp_avg_sq"]
-                    scale_exp_avg = state["scale_exp_avg"]
                     scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                           value=1 - beta2)

-                    scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
                     scale_bias_correction2 = 1 - beta2 ** step

                     scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])

                     scale_alpha = -lr*(scale_bias_correction2 ** 0.5)

-                    # scale_delta is going to be the change in log-scale, i.e. if
+                    # scale_delta is going to be the change in log-scale (before momentum), i.e. if
                     # p = underlying_param * scale.exp(),
                     # delta is the change in `scale`.
-                    scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
+                    scale_delta = scale_alpha * (scale_deriv / scale_denom)

                     # the following is equivalent to:
                     # if torch.all(state["param_rms"] > rms_max):
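
Note (not part of the diff): the comments in this hunk describe scale_deriv as the derivative of the loss with respect to a log-space scale on the parameters, i.e. if p = underlying_param * scale.exp(), then d(loss)/d(scale) at the current point is (p * grad).sum(). A minimal standalone check of that identity; the toy loss and tensor shape here are made up for illustration.

import torch

torch.manual_seed(0)
p = torch.randn(5, requires_grad=True)

def loss_fn(x):
    # Hypothetical loss, only for the numerical check.
    return (x ** 3).sum() + (2.0 * x).sum()

# Ordinary gradient w.r.t. p.
grad = torch.autograd.grad(loss_fn(p), p)[0]

# Derivative w.r.t. a log-space scale s, viewing p as underlying_param * s.exp()
# with s currently 0 (so the current parameter value is unchanged).
s = torch.zeros((), requires_grad=True)
scale_grad = torch.autograd.grad(loss_fn(p.detach() * s.exp()), s)[0]

# The quantity the optimizer computes:
scale_deriv = (p * grad).sum()

print(scale_grad.item(), scale_deriv.item())  # agree up to floating-point error
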
@@ -622,9 +621,8 @@ class Cain(Optimizer):
                                                          dtype=scale_delta.dtype),
                                              scale_delta)

-                    p.mul_(1 + scale_delta)
+                    exp_avg.add_(p, alpha=scale_delta)

-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                 # forget bias_correction1. We use normal momentum. exp_avg really
                 # just stores the moving-average gradient step.
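
Note (not part of the diff): as I read this hunk, the scale change is no longer applied to p in place via p.mul_(1 + scale_delta); instead scale_delta * p is folded into exp_avg, which is added to p at the end of the step (the p.add_(exp_avg) in the final hunk), so the log-scale change also passes through exp_avg's beta1 momentum on later steps. Within a single step the two routes give the same parameters, as this standalone sketch shows; the tensor values and the stand-in gradient step are invented for illustration.

import torch

torch.manual_seed(0)
p_old_route = torch.randn(4)
p_new_route = p_old_route.clone()

scale_delta = -0.01                   # change in log-scale for this step
grad_step = torch.randn(4) * 1e-3     # stand-in for the alpha * this_delta term
exp_avg_old = grad_step.clone()       # old code: exp_avg holds only the gradient step
exp_avg_new = grad_step.clone()

# Old route: scale applied directly to p, then the usual p.add_(exp_avg).
p_old_route.mul_(1 + scale_delta)
p_old_route.add_(exp_avg_old)

# New route: scale change folded into exp_avg, then the same p.add_(exp_avg).
exp_avg_new.add_(p_new_route, alpha=scale_delta)
p_new_route.add_(exp_avg_new)

print(torch.allclose(p_old_route, p_new_route))  # True: identical within one step
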
@@ -640,7 +638,6 @@ class Cain(Optimizer):
                 this_delta = grad / denom
                 this_delta = self._change_coordinates(this_delta, state, forward=False)

-
                 alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
                 if p.numel() > 1:
                     if step % 10 == 0:
@@ -651,8 +648,7 @@ class Cain(Optimizer):

                 # treat/repurpose exp_avg as moving-average of `this_delta`
                 # don't use alpha=alpha as I don't think
-                exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
+                exp_avg.add_(this_delta, alpha=alpha)

-
                 p.add_(exp_avg)
                 state["step"] += 1
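
Note (not part of the diff): putting the hunks together, this is a condensed reading of the per-step flow after this commit. It is only a sketch, not the real Cain.step(): the coordinate change (_change_coordinates), the real denom/this_delta computation, bias_correction2 for the gradient part, and the param_rms clamp are omitted, and the hyperparameter defaults are stand-ins for the group settings.

import torch

def cain_like_step(p, grad, state, lr=3e-4, beta1=0.9, beta2=0.98, eps=1e-8):
    # Condensed sketch of the scaling-related flow after this commit.
    exp_avg = state["exp_avg"]
    scale_exp_avg_sq = state["scale_exp_avg_sq"]
    step = state["step"]

    exp_avg.mul_(beta1)  # momentum decay, now done once near the top

    if p.numel() > 1:
        # Learn a log-space scale on the parameters.
        scale_deriv = (p * grad).sum()
        scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                              value=1 - beta2)
        scale_bias_correction2 = 1 - beta2 ** step
        scale_denom = scale_exp_avg_sq.sqrt().add_(eps)
        scale_alpha = -lr * (scale_bias_correction2 ** 0.5)
        # No separate scale_exp_avg any more: the raw derivative is used and
        # the log-scale change is routed through exp_avg.
        scale_delta = scale_alpha * (scale_deriv / scale_denom)
        exp_avg.add_(scale_delta * p)  # the diff writes exp_avg.add_(p, alpha=scale_delta)

    # Stand-in for the real gradient step (grad / denom after the coordinate
    # change); exp_avg is repurposed as a moving average of it.
    this_delta = grad
    exp_avg.add_(this_delta, alpha=-lr * (1 - beta1))

    p.add_(exp_avg)
    state["step"] += 1

# Hypothetical usage with made-up tensors:
p = torch.randn(10)
state = {"exp_avg": torch.zeros(10),
         "scale_exp_avg_sq": torch.zeros(()),
         "step": 1}
cain_like_step(p, grad=torch.randn(10), state=state)
print(p)
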