Code cleanup

Daniel Povey 2022-06-03 11:54:12 +08:00
parent d6e65a0e7f
commit 3fff0c75bb


@@ -84,13 +84,6 @@ class Cain(Optimizer):
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
max_eff_lr (float, optional): maximum effective learning rate for
natural-gradient update; this limits how aggressively we accelerate
directions with small gradients. The idea is that you might want to
leave this constant as the regular learning rate decreases; but
we can investigate alternatives here. If you set this to <= 0,
it disables the natural-gradient aspect, giving a more ADAM-like
update (in changed co-ordinates)
ng_scale (float, optional): scale on natural-gradient-like update,
interpolating it with an Adam-type update.
betas (Tuple[float, float], optional): coefficients used for computing
@@ -124,7 +117,6 @@ class Cain(Optimizer):
self,
params,
lr=1e-3,
max_eff_lr=3e-3,
ng_pow=-0.5,
ng_eps=0.05,
betas=(0.9, 0.98),
@@ -159,7 +151,6 @@ class Cain(Optimizer):
defaults = dict(
lr=lr,
max_eff_lr=max_eff_lr,
ng_eps=ng_eps,
ng_pow=ng_pow,
betas=betas,
@@ -189,7 +180,6 @@ class Cain(Optimizer):
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
max_eff_lr = group["max_eff_lr"]
ng_eps = group["ng_eps"]
ng_pow = group["ng_pow"]
target_rms = group["target_rms"]
@@ -197,8 +187,6 @@ class Cain(Optimizer):
eps = group["eps"]
rms_max = group["rms_max"]
assert lr <= max_eff_lr
for p in group["params"]:
if p.grad is None:
continue
@@ -231,10 +219,6 @@ class Cain(Optimizer):
# that also emulates Eve.
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
dtype=p.dtype)
# alpha is a scale related to the natural-gradient update,
# see self._get_grad_scale()
state["alpha"] = torch.ones((), device=p.device,
dtype=p.dtype)
@@ -339,92 +323,6 @@ class Cain(Optimizer):
return loss
def _get_grad_scale(self,
state: dict,
step: int,
exp_avg_sq: Tensor,
bias_correction2: float,
eps: float,
beta: float,
ng_scale: float):
"""
This function returns a "gradient-scale" value that we'll multiply the gradients by
before multiplying by -lr and applying momentum.
Args:
state: state-dict for this parameter, needed for its "alpha" member
step: the current step index
exp_avg_sq: Moving-average of gradient-squared for each element of the current
parameter matrix (possibly after a co-ordinate change).
bias_correction2: A value that will be normally 1.0, but will be less than 1.0
near the start of training or just after we re-estimated the co-ordinate
changes; we can divide exp_avg_sq by this to get an expected gradient-squared.
eps: Epsilon value that prevents division by zero; dimensionally, this is
a gradient magnitude (not squared-gradient magnitude).
beta: A value <= 1.0, e.g. 0.1; it equals lr / max_eff_lr, and represents
the inverse of the "maximum acceleration" we allow the natural gradient to
give, versus regular ADAM. Think of 1/beta as the maximum normalized
gradient root-mean-square value we allow.
We'll explain what this function does by comparing (i) basic ADAM-like update,
(ii) basic natural-gradient update, (iii) our update which combines properties
of both.
The basic ADAM-like version of this would be:
denom = (exp_avg_sq.sqrt()).add_(eps),
and ans=1/denom; so after we divide the gradients by this, they will have unit expected
squared value.
The natural-gradient version of this would be:
denom = (exp_avg_sq).add_(eps * eps)
base_denom = (exp_avg_sq.sqrt()).add_(eps),
denom = denom * (base_denom / denom).mean()
what this does is: compute 'denom' with no sqrt(), and then scale it in such
a way that ([root-mean-square gradient magnitude] / denom).mean() == 1.0, so
the overall learning speed is about the same as the baseline.
We will be returning:
denom = alpha * (exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
and are aiming for alpha to satisfy the following condition:
((exp_avg_sq.sqrt() + eps) / denom).mean() == 1.0 (eqn:1)
... (eqn:1) means that the overall learning rate / "rate of progress"
is about the same as if denom was just (exp_avg_sq.sqrt() + eps). Also,
the beta term in "denom" ensures that:
(1/denom) <= 1/beta * (1/adam_denom)
which ensures that the speed of update for any given element will never
be greater than the normal ADAM-type update by more than 1/beta.
The way we update alpha is as follows. We always have the previous iteration's
alpha, call this alpha_prev. We compute:
denom_prev = (alpha_prev * exp_avg_sq) + beta * exp_avg_sq.sqrt() + eps
and:
mean_prev = ((exp_avg_sq.sqrt() + eps) / denom_prev).mean()
... now, mean_prev varies *at most* like 1/alpha, meaning its maximum sensitivity
to alpha is an inverse relationship. So we update alpha as:
alpha_cur = alpha_prev * mean_prev
(the idea is that this will eventually converge; and meanwhile, we'll just have
a slightly inaccurate alpha). If `step` is large, we'll assume alpha has converged
and always return the denominator with the old alpha.
In the end, we interpolate the natural gradient update with the regular ADAM-like
one, as in: grad_scale = ng_scale/denom + (1-ng_scale)/adam_denom
"""
num_steps = 10 if step < 3 else 2 if step < 100 else 1
for _ in range(num_steps):
    # our update for alpha is iterative since there isn't a closed-form solution.
    alpha = state["alpha"]
    adam_denom = exp_avg_sq.sqrt()
    adam_denom_eps = adam_denom + eps
    denom = alpha * exp_avg_sq + beta * adam_denom + eps
    mean_speed = ((adam_denom_eps) / denom).mean()
    alpha.fill_(alpha * mean_speed)
return ng_scale / denom + (1-ng_scale) / adam_denom_eps
def _change_coordinates(self,
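
For reference, the natural-gradient scaling removed in this commit can be summarized
in a short standalone sketch. This is an illustration only: the function name
grad_scale_sketch and the demonstration values below are assumptions, not code from
the repository.

import torch

def grad_scale_sketch(alpha: torch.Tensor,
                      exp_avg_sq: torch.Tensor,
                      eps: float,
                      beta: float,
                      ng_scale: float,
                      num_steps: int = 10) -> torch.Tensor:
    # Hypothetical re-implementation of the removed _get_grad_scale();
    # alpha is a 0-dim tensor updated in place, as in the original state dict.
    for _ in range(num_steps):
        # Iterate toward the alpha satisfying (eqn:1); there is no
        # closed-form solution.
        adam_denom = exp_avg_sq.sqrt()
        adam_denom_eps = adam_denom + eps
        denom = alpha * exp_avg_sq + beta * adam_denom + eps
        # mean_speed < 1 means denom is on average too large (progress slower
        # than the ADAM baseline), so alpha shrinks; mean_speed > 1 grows it.
        mean_speed = (adam_denom_eps / denom).mean()
        alpha.fill_(alpha * mean_speed)
    # Interpolate the natural-gradient-like scale with the ADAM-like one.
    return ng_scale / denom + (1 - ng_scale) / adam_denom_eps

# Demonstration that the iteration settles where (eqn:1) holds:
torch.manual_seed(0)
exp_avg_sq = torch.rand(1000) ** 2
alpha = torch.ones(())
eps, beta = 1e-8, 0.1
grad_scale_sketch(alpha, exp_avg_sq, eps=eps, beta=beta, ng_scale=1.0)
denom = alpha * exp_avg_sq + beta * exp_avg_sq.sqrt() + eps
print(((exp_avg_sq.sqrt() + eps) / denom).mean())  # prints approximately 1.0

Note the role of the beta * adam_denom term: it guarantees 1/denom <= (1/beta) *
(1/adam_denom_eps), so no single element is accelerated more than 1/beta times the
corresponding ADAM-type update.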