draft, not working, will edit locally

This commit is contained in:
parent 0b811546f3
commit d64cb1cb48

@@ -24,150 +24,70 @@ from torch.optim import Optimizer
from icefall import diagnostics  # only for testing code
import logging


class NeutralGradient(Optimizer):
class LearnedGradient(Optimizer):
"""
|
||||
Implements 'Neutral Gradient' update (a play on "natural gradient"). The aim is to
|
||||
multiply the gradient by a symmetric positive definite matrix that transforms
|
||||
the covariance of gradients to be the same as the covariance of parameters.
|
||||
(Then multiply by the learning rate). Since, at any given time, we only have
|
||||
one parameter tensor, in a sense it doesn't make sense to speak of the covariance
|
||||
matrix of parameters; but the idea is, we do this separately for every axis of
|
||||
every tensor and view the indexes into all the other axes of that tensor, as
|
||||
a collection of vectors that we can compute the covariance of. There may
|
||||
not be enough such vectors to estimate a non-singular covariance matrix; in
|
||||
any case, we smooth with a diagonal matrix. (If the dimension is too large,
|
||||
we may just use the diagonal only).
|
||||
Implements 'Learned Gradient' update. It's a learned positive definite learning-rate
|
||||
matrix for each dimension of each tensor, that we multiply the derivative by, and
|
||||
then scale the delta/param-change before the update.
|
||||
|
||||
|
||||
    Args:
      params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
      lr:  The learning rate.  We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
      betas:  Momentum constants for regular momentum, and second-order (per-element) stats
      meta_lr_scale:  The learning rate for the learning-rate matrix; this is a scale on
           the regular learning rate.
      betas:  beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
           scalar tensors.  Must satisfy 0 < beta1 <= beta2 < 1.
      scale_speed:  This scales the learning rate for purposes of learning the overall scale
           of each tensor
      eps:  An epsilon to prevent division by zero
      param_eps:  An epsilon on the rms value of the parameter, such that when the parameter
           gets smaller than this we'll start using a fixed learning rate, not
           decreasing with the parameter norm.
      param_rel_eps:  An epsilon that limits how different parameter rms estimates can
           be from each other within the same tensor
      param_rel_max:  Limits how big we allow the parameter variance eigs to be relative to the
           original mean.  Setting this to a not-very-large value like 1 ensures that
           we only apply the param scaling for smaller-than-average, not larger-than-average,
           eigenvalues, which limits the dispersion of eigenvalues of the parameter tensors.
      param_reverse_cutoff:  Once the variance in a parameter direction becomes over param_reverse_cutoff times the
           average variance, we start to train slower in this direction, which should help
           avoid too much eigenvalue dispersion in the parameter matrices.
      param_max:  To prevent parameter tensors getting too large, we will clip elements to
           -param_max..param_max.
      max_fullcov_size:  We will only use the full-covariance (not-diagonal) update for
           dimensions with size <= max_fullcov_size; for those larger, we will use a diagonal
           formula
      estimate_period:  Every `estimate_period` steps, we will re-estimate projections for
           the dimensions we are not treating diagonally.  Each tensor will have a separately
           randomly-chosen batch index modulo `estimate_period` to do this, to
           decrease maximum memory consumption.
      stats_steps:  The number of minibatches of stats that we accumulate for purposes of basis
           changes.  Must be less than estimate_period
      param_pow:  A scale 0 <= param_pow <= 1.0 where 1.0 means we try to fully normalize
           the parameter scale for each dimension (i.e. try to learn with speed proportional
           to the parameter rms) and 0.0 means no such normalization at all (only a scalar
           per tensor).  In any case we fully normalize the parameter scale at the
           whole-tensor level, for non-scalar tensors.
      grad_pow:  A scale 0 <= grad_pow <= 1.0 where 1.0 means we fully normalize the
           gradient magnitude.
      grad_min_rand:  Minimum proportion of random tensor that we add to the gradient covariance,
           to make the gradient covariance have less effect.
      lr_for_speedup:  A learning rate that determines how much speedup we aim for by
           accumulating gradients over multiple batches (the update is done after
           different delays for different parameters, based on an estimate of their
           importance).  The target speedup factor will be (lr_for_speedup / lr).
      speedup_first_step:  The first step on which we start doing the speedup.
      speedup_recalibrate_period:  The number of steps between recalibrating the speedup
           (i.e. choosing the periodicity for updating each parameter tensor)
      param_min_rms:  Minimum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll keep it >= this size)
      param_max_rms:  Maximum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll keep it <= this size)
      lr_grad_period:  The periodicity, in steps, with which we compute the gradient w.r.t.
           the learning-rate matrix.  Must be more than 1, because we delay the
           gradient when we do this, and doing this every time might contribute to
           instability.
      lr_est_period:  The periodicity, in steps, with which we update the learning-rate matrix;
           must be a multiple of lr_grad_period.
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        betas=(0.9, 0.98, 0.95),
        meta_lr_scale=1.0,
        betas=(0.9, 0.98),
        scale_speed=0.1,
        grad_eps=1e-10,
        param_eps=1.0e-06,
        param_rel_eps=1.0e-04,
        param_rel_max=1.0,
        param_reverse_cutoff=4.0,
        param_max_rms=2.0,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        max_fullcov_size=1023,
        estimate_period=2000,
        stats_steps=200,
        param_pow=0.5,
        grad_pow=1.0,
        grad_min_rand=0.0,
        lr_for_speedup=0.03,
        speedup_recalibrate_period=100,
        speedup_first_step=500,
        param_max_rms=2.0,
        lr_grad_period=4,
        lr_est_period=20,
        max_step_scale=5.0
    ):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 < betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 < scale_speed < 1.0:
            raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
        if not 0.0 <= grad_eps < 0.1:
            raise ValueError("Invalid grad_eps value: {}".format(grad_eps))
        if not 0.0 < param_eps < 0.1:
            raise ValueError("Invalid param_eps value: {}".format(param_eps))
        if not 0.1 < param_max_rms < 1.0e+10:
            raise ValueError("Invalid param_max_rms value: {}".format(param_max_rms))
        if not 0.0 < param_min_rms < 0.5:
            raise ValueError("Invalid param_min_rms value: {}".format(param_min_rms))
        if not 0 < estimate_period:
            raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
        if not 0 < stats_steps < estimate_period:
            raise ValueError("Invalid stats_steps value: {}".format(stats_steps))
        if not 0 < speedup_first_step:
            raise ValueError("Invalid speedup_first_step value: {}".format(speedup_first_step))
        if not 0 < speedup_recalibrate_period:
            raise ValueError("Invalid speedup_recalibrate_period value: {}".format(speedup_recalibrate_period))

        # TODO: check all args
        assert 0 < beta < 1
        assert lr_est_period % lr_grad_period == 0

        defaults = dict(
            lr=lr,
            betas=betas,
            meta_lr_scale=meta_lr_scale,
            beta=beta,
            grad_beta=grad_beta,
            scale_speed=scale_speed,
            grad_eps=grad_eps,
            param_eps=param_eps,
            param_rel_eps=param_rel_eps,
            param_rel_max=param_rel_max,
            param_reverse_cutoff=param_reverse_cutoff,
            param_max_rms=param_max_rms,
            param_min_rms=param_min_rms,
            max_fullcov_size=max_fullcov_size,
            estimate_period=estimate_period,
            stats_steps=stats_steps,
            param_pow=param_pow,
            grad_pow=grad_pow,
            grad_min_rand=grad_min_rand,
            lr_for_speedup=lr_for_speedup,
            speedup_recalibrate_period=speedup_recalibrate_period,
            speedup_first_step=speedup_first_step,
            speedup_step=-speedup_first_step,
            eps=eps,
            lr_est_period=lr_est_period,
            max_step_scale=max_step_scale,
        )
        super(NeutralGradient, self).__init__(params, defaults)

        super(LearnedGradient, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(NeutralGradient, self).__setstate__(state)
        super(LearnedGradient, self).__setstate__(state)


    @torch.no_grad()
@@ -185,39 +105,13 @@ class NeutralGradient(Optimizer):

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2, beta3 = group["betas"]
            beta1, beta2 = group["betas"]
            meta_lr_scale = group["meta_lr_scale"]
            scale_speed = group["scale_speed"]
            grad_eps = group["grad_eps"]
            param_eps = group["param_eps"]
            param_rel_eps = group["param_rel_eps"]
            param_rel_max = group["param_rel_max"]
            param_reverse_cutoff = group["param_reverse_cutoff"]
            param_max_rms = group["param_max_rms"]
            param_min_rms = group["param_min_rms"]
            max_fullcov_size = group["max_fullcov_size"]
            estimate_period = group["estimate_period"]
            stats_steps = group["stats_steps"]
            param_pow = group["param_pow"]
            grad_pow = group["grad_pow"]
            grad_min_rand = group["grad_min_rand"]
            lr_for_speedup = group["lr_for_speedup"]
            speedup_first_step = group["speedup_first_step"]
            speedup_recalibrate_period = group["speedup_recalibrate_period"]
            eps = group["eps"]
            lr_est_period = group["lr_est_period"]
            max_step_scale = group["max_step_scale"]

            speedup_step = group["speedup_step"]
            group["speedup_step"] = speedup_step + 1

            recalibrate = (speedup_step >= 0 and
                           speedup_step % speedup_recalibrate_period == 0)
            if recalibrate:
                # param_prods will be an array of products, one per parameter tensor, equal to:
                #   E[grad . step],
                # where "step" is the individual step, before momentum (i.e. computed as if beta1==0).
                # We compute these in such a way as to reflect the expectation if there is no speedup
                # (i.e. assuming we are doing the update on every step with no gradient accumulation).
                # The idea is that this product reflects the importance of each parameter matrix,
                # so that parameter matrices with larger values will need more frequent updates.
                self._recalibrate_speedup(group)

            for p in group["params"]:
                if p.grad is None:
@@ -227,121 +121,74 @@ class NeutralGradient(Optimizer):
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "NeutralGradient optimizer does not support sparse gradients"
                        "LearnedGradient optimizer does not support sparse gradients"
                    )
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # For speed, we will actually take a step every
                    # `step_period` times step() is called; we will store the
                    # gradient in state["grad"] until state["n_cached_grads"]
                    # reaches step_period, at which point we'll really do the
                    # update.
                    state["step_period"] = 1
                    state["n_cached_grads"] = 0
                    # Allocate these cached_grad members immediately even though
                    # we may not need them immediately, so that any OOMs will
                    # occur quickly.
                    state["cached_grad"] = torch.zeros_like(

                    # moving_grad is a moving average of the raw gradient.  We do the accumulation
                    # on the input side because it makes it easier to get grads w.r.t. the learning rate matrices.
                    state["moving_grad"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                    state["step"] = 0
                    # Parameter step (used to implement momentum).
                    state["delta"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient value, all
                    # tensors have this.
                    state["exp_avg_sq"] = torch.zeros_like(p)

                    kwargs = {'device': p.device, 'dtype': p.dtype}
                    # grad_times_step is an estimate of the importance of each parameter, for purposes
                    # of determining the step_period (by which we take steps only periodically, to save
                    # time)
                    state["grad_times_step"] = torch.zeros((), **kwargs)

                    if p.numel() > 1:
                        is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias

                        # "scale" is just a per-tensor scale on the learning rate.
                        param_rms = (p**2).mean().sqrt().add_(param_min_rms)
                        # "scale" is just a per-tensor scalar that determines the
                        # rate of learning (in terms of the rms value of the change
                        # in parameter)
                        param_rms = (p**2).mean().sqrt().add_(eps)
                        state["scale"] = param_rms

                        # we learn the scale on the parameters in a way that
                        # also emulates Eve.  scale_exp_avg_sq is the
                        # moving-average square of the gradient w.r.t. this
                        # scalar scale.
                        state["scale_exp_avg_sq"] = torch.zeros((), **kwargs)

                        if not is_one_axis:
                            # each parameter has a different random time, modulo estimate_period,
                            # to re-estimate the projections.  "step_within_period" will increase by
                            # 1 on each step, and will be reset to 0 when it reaches estimate_period.
                            state["step_within_period"] = random.randint(0,
                                                                          estimate_period-stats_steps)

                            state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)

                            used_scale = False
                            for dim in range(p.ndim):
                                size = p.shape[dim]
                                if size == 1 or size == p.numel():
                                    # if size == p.numel(), is_one_axis == True above
                                    continue
                                elif size > max_fullcov_size:
                                    # diagonal only...
                                    state[f"proj_{dim}"] = torch.ones(size, **kwargs)
                                else:
                                    state[f"proj_{dim}"] = torch.eye(size, **kwargs)

                                # we will decay the param stats via beta3; they are only
                                # accumulated once every `estimate_period`
                                state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
                                state[f"lr_{dim}"] = torch.eye(size, **kwargs)

                                if not used_scale:
                                    param_rms = (p**2).mean().sqrt().add_(param_eps)
                                    state[f"proj_{dim}"] *= param_rms
                                    used_scale = True

                                state[f"grad_cov_count_{dim}"] = 0
                                # we will also be using
                                #   state[f"grad_cov_{dim}"]
                                # but we'll allocate this and deallocate it as needed, for individual
                                # parameters, to save memory.

                step_period = state["step_period"]
                n_cached_grads = state["n_cached_grads"]
                if step_period > 1 or n_cached_grads > 0:
                    cached_grad = state["cached_grad"]
                    n_cached_grads += 1
                    state["n_cached_grads"] = n_cached_grads
                    # the point of random_step_prob is to inject a little randomness
                    # into the timing of the updates, so they don't stay synchronized.
                    random_step_prob = 0.05 / step_period
                    if n_cached_grads >= step_period or random.random() < random_step_prob:
                        # We will continue to the code below, and do a step
                        grad += cached_grad
                        cached_grad.zero_()
                        state["n_cached_grads"] = 0
                        # grad_cov_{dim} is the covariance of gradients on this axis,
                        # treating all other axes as a batch axis.  This is needed
                        # only for purposes of computing the scalar factor on the
                        # learning rate; and, as a result of this, also contributes something
                        # to the gradient w.r.t. f"lr_{dim}", which is one reason
                        # why we allocate a variable to keep track of its moving average
                        # instead of just using a temporary and smoothing the scalar factor.
                        state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
                    else:
                        cached_grad += grad
                        # Don't do a step this time.
                        continue
                else:
                    # We are doing a step on every iteration.
                    n_cached_grads = 1  # will be used to scale gradients below.
                    state["exp_avg_sq"] = torch.zeros_like(p)

                delta = state["delta"]

                moving_grad = state["moving_grad"]
                step = state["step"]

                delta.mul_(beta1)

                if p.numel() > 1:
                    is_one_axis = (p.numel() in list(p.shape))  # e.g. a bias
                if p.numel() == 1:
                    # For the scalar case we just use Adam.
                    self._scalar_update(beta1, beta2, lr, state)
                    continue
                else:
                    if step % lr_est_period == 0:
                        self._step_with_update(group, p, grad, state)
                    else:
                        self._step_no_update(group, p, grad, state)
                state["step"] = step + 1


                if True:
                    # This block learns the scale on the parameters.
@@ -363,7 +210,6 @@ class NeutralGradient(Optimizer):

                    scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)


                    enforce_rms_period = 2
                    scale_norm_deriv = (scale_deriv / scale_denom)
                    if step % enforce_rms_period == 0:
@@ -493,10 +339,160 @@ class NeutralGradient(Optimizer):
        return loss


    def _scalar_update(self,
                       beta1: float,
                       beta2: float,
                       lr: float,
                       state: dict):
        assert False  # TODO: implement


    def _step_with_update(self,
                          group: dict,
                          p: Tensor,
                          grad: Tensor,
                          state: dict):
        # meta_lr is the learning rate for the learning rate matrix.
        lr = group["lr"]
        meta_lr = lr * group["meta_lr_scale"]
        beta1 = group["betas"][0]
        eps = group["eps"]

        moving_g = state["moving_grad"]
        moving_g.mul_(beta1)

        ndim = p.ndim

        # Update grad_cov
        self._store_grad_stats(grad, state)

        for dim in range(ndim):
            if p.shape[dim] != 1:
                state[f"lr_{dim}"].requires_grad = True


        with torch.enable_grad():
            # set requires_grad to True on state[f"lr_{dim}"]
            for dim in range(ndim):
                if p.shape[dim] != 1:
                    state[f"lr_{dim}"].requires_grad = True

            moving_d = self._multiply_by_lr(moving_g)

            # moving_d_rms is an rms value of moving_d computed via inner products with
            # the grad, which is a more basis-independent way of computing the rms value.
            # You have to view this as the rms times an arbitrary multiplicative factor.
            moving_d_rms = (self._multiply_by_grad_cov(moving_d) * moving_d).mean().sqrt()

            # moving_d_norm is the parameter "delta" moving_d, renormalized via an inner
            # product with the gradient covariances (for basis independence), but still with
            # an arbitrary multiplicative constant.  We don't care about this constant
            # because we are going to normalize it out anyway.
            moving_d_norm = moving_d / moving_d_rms

            # View the parameter p as [something] - alpha * moving_d_norm, where alpha contains
            # the learning rate and other factors, but we don't care about its value
            # alpha because we're going to divide by the gradient norm anyway.
            # The (pseudo-) loss function would be (param * grad).sum(), which, ignoring
            # [something], is -alpha * (moving_d_norm * grad).sum(), and ignoring
            # alpha that is -(moving_d_norm * grad).sum(), so we call this neg_loss.
            neg_loss = (grad * moving_d_norm).sum()

            # Now there will be grads in the state[f"lr_{dim}"] that are the negative of loss
            # function grads, i.e. we want to go in the forward grad direction.
            neg_loss.backward()


        for dim in range(ndim):
            if p.shape[dim] != 1:
                self._update_lr(state[f"lr_{dim}"], meta_lr)


        # moving_d contains a moving average of gradients, (1-beta) * (beta + beta^2 + beta^3, etc.)
        # Suppose the rms value of each individual gradient were r.  Assuming these are independent,
        # the variance of the elements of moving_d would be r^2 * (1-beta)^2 * (beta^2 + beta^4 + ... )
        # ((The reason why the series in parentheses starts with beta is because on "update iters",
        # we delay the current grad term (`grad`), which would normally appear with coefficient 1,
        # until the next iter))
        # which equals r^2 * (1-beta)^2 * beta^2 / (1-beta^2), so the measured rms of moving_d would be
        #   moving_d_rms = individual_d_rms * (1-beta) * beta / sqrt(1-beta^2), so
        #   individual_d_rms = moving_d_rms * sqrt(1-beta^2) / ((1-beta) * beta)
        # ... this would be the root-mean-square value of the individual deltas, if we were to process
        # them individually (if they were independent), and this is what we want to normalize by for
        # greater independence of the beta value, i.e. so beta only affects momentum and not the
        # overall speed.

        moving_d_rms = (moving_d**2).mean().sqrt() + eps
        # e.g. if beta1==0.98, this formula gives a factor of 9.95.
        d_rms = moving_d_rms * (((1-beta1*beta1)**0.5) / (1-beta1))

        alpha = -lr / d_rms
        p.add_(moving_d, alpha=alpha)
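
        # Quick arithmetic check of the factor above (not part of this commit), assuming beta1 == 0.98:
        #   (1 - 0.98**2) ** 0.5 / (1 - 0.98) = 0.0396 ** 0.5 / 0.02 ≈ 0.199 / 0.02 ≈ 9.95,
        # which matches the value quoted in the comment.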

    def _update_lr(self,
                   lr_mat: Tensor,
                   meta_lr: float) -> None:
        """
        Do one step of update for the learning rate matrix 'lr_mat', which we assume already
        has its `grad` property populated with the grad of the negative loss (our special
        loss that we use to train the learning rate matrix).  Also deletes lr_mat.grad.

        Args:
          lr_mat:  The symmetric positive-semidefinite (PSD) learning-rate matrix.
                   Will have its .grad populated with the derivative of a negated
                   loss function (i.e. to be maximized), that we use to optimize
                   this learning-rate matrix.  lr_mat.grad will be deleted and
                   lr_mat will be updated by one step.
          meta_lr: The speed with which we learn lr_mat.
        """
        grad = lr_mat.grad
        del lr_mat.grad
        lr_mat.requires_grad = False

        # OK, suppose lr_mat = M, which is symmetric positive definite.
        # Decompose: M = A N A^T
        # where N is numerically equal to M (but a separate variable), and A is a variable numerically
        # equal to I.  To make it easier to keep the result positive definite, we learn A rather than N.
        # We'll use a convention where A_grad has the same orientation as A, i.e. A_grad_ij equals the grad
        # of loss w.r.t A_ij.
        # The grad w.r.t A equals (N M_grad)^T + (M_grad N) = 2 (M_grad N).  We don't care about scalar
        # factors at this point (we'll normalize them out), so ignore the factor of 2, and say:
        #   A_grad = M_grad N.
        A_grad = torch.matmul(grad, lr_mat)

        # Normalize the size of `A_grad` (there are arbitrary scalar factors in this loss function),
        # and include meta_lr.
        A_grad = A_grad * (meta_lr / ((A_grad ** 2).mean().sqrt() + 1.0e-20))

        # So the updated A is just I plus meta_lr * normalize[A_grad]

        size = lr_mat.shape[0]
        A = torch.eye(size, device=lr_mat.device, dtype=lr_mat.dtype) + meta_lr * A_grad

        # the result is positive semidefinite if the initial lr_mat was positive semidefinite,
        # because it is of the form X X^T where X = A lr_mat^{0.5}
        torch.matmul(A, torch.matmul(lr_mat, A.t()), out=lr_mat)

        # Normalize the scale of lr_mat: we always ensure lr_mat.diag().mean() is 1.0
        lr_mat *= 1.0 / lr_mat.diag().mean()
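
        # Illustrative check (not part of this commit) that this update keeps lr_mat symmetric
        # PSD with unit mean diagonal, under the assumptions above: if M is PSD then
        # A M A^T = (A M^{1/2}) (A M^{1/2})^T is PSD for any A, and dividing by diag().mean()
        # only rescales it.  For example, as hypothetical standalone code:
        #   M = torch.eye(4); A = torch.eye(4) + 0.01 * torch.randn(4, 4)
        #   M = A @ M @ A.t(); M /= M.diag().mean()
        #   assert torch.allclose(M, M.t()) and torch.linalg.eigvalsh(M).min() >= 0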


    def _step_no_update(self,
                        group: dict,
                        p: Tensor,
                        grad: Tensor,
                        state: dict):
        pass


    def _store_grad_stats(self,
                          grad: Tensor,
                          state: dict,
                          max_fullcov_size: int) -> None:
                          beta2: float) -> None:
        """
        Accumulate some stats for each dimension of the tensor, about the covariance of gradients.
        The stats are just summed, not decayed; we will re-estimate the projections periodically,
@@ -506,23 +502,18 @@ class NeutralGradient(Optimizer):
        kwargs = {'device': grad.device, 'dtype': grad.dtype}
        for dim in range(ndim):
            size = grad.shape[dim]
            is_fullcov = (size <= max_fullcov_size)
            if size == 1 or not is_fullcov:
            if size == 1:
                continue
            name = f"grad_cov_{dim}"
            if name not in state:
                state[f"grad_cov_count_{dim}"] = 0
                state[name] = torch.zeros(size, size, **kwargs)

            grad_cov = state[name]
            # just sum up the stats without normalization; we will divide by
            # count in _estimate_projections()
            g = grad.transpose(dim, -1)
            g = g.reshape(-1, g.shape[-1])
            grad_cov.add_(torch.matmul(g.t(), g))
            count = g.shape[0]
            state[f"grad_cov_count_{dim}"] += count

            # We don't care about the scalar factor, we will normalize when we
            # use it.
            grad_cov.mul_(beta2).add_(torch.matmul(g.t(), g))
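
            # Shape sketch (not part of this commit): for a gradient of shape (512, 256) and
            # dim == 0, g = grad.transpose(0, -1).reshape(-1, 512) has shape (256, 512), so
            # g.t() @ g is (512, 512): a covariance over dimension 0, with the other
            # dimensions treated as a batch of 256 samples.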


    def _estimate_projections(self,
                              p: Tensor,
@@ -623,90 +614,56 @@ class NeutralGradient(Optimizer):
                              param_reverse_cutoff)


    def _estimate_param_scale(self,
                              p: Tensor,
                              state: dict,
                              param_pow: float,
                              param_eps: float,
                              param_rel_eps: float,
                              param_rel_max: float,
                              param_reverse_cutoff: float) -> None:

    def _multiply_by_lr(self,
                        grad: Tensor,
                        state: dict,
                        forward: bool) -> Tensor:
        """
        This is called from _estimate_projections() after suitably "diagonalizing" bases
        have already been estimated and written in state[f"proj_{dim}"] for the
        non-trivial dimensions of `p`.  This function computes a factored estimate
        of the parameter variance, and uses this to apply scales to the rows of the
        projections.

        See the documentation of _estimate_projections() for documentation of the
        args.
        Multiply the grad by a positive semi-definite matrix for each dimension,
        interpreted as a non-scalar learning-rate factor.
        """
        # Rotate `p` to the diagonalized basis.  At this point there is no scaling,
        # just an orthogonal transformation; this function is going to add the
        # scaling to state[f"proj_{dim}"]
        rotated_p = self._change_coordinates(p, state, forward=True)

        params_sq = rotated_p**2
        params_sq.add_(param_eps*param_eps +
                       param_rel_eps*param_rel_eps * params_sq.mean())


        params_sq_norm = params_sq

        # param_diag_vars is the diagonals of the parameter variances
        # on each tensor dim, with an arbitrary scalar factor.
        param_diag_vars = [None] * p.ndim

        for _ in range(3):
            # Iterate 3 times; this should be enough to converge.
            for dim in range(p.ndim):
                size = p.shape[dim]
                if size == 1:
                    continue
                # p will have at least one non-trivial dim.
                other_dims = [i for i in range(p.ndim) if i != dim]
                # Compute diagonal variance along this dimension
                # (this is after normalizing previous dimensions' variance)
                this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
                params_sq_norm = params_sq_norm / this_var


                if param_diag_vars[dim] is None:
                    param_diag_vars[dim] = this_var.reshape(size)
                else:
                    param_diag_vars[dim] *= this_var.reshape(size)

        for dim in range(p.ndim):
            size = p.shape[dim]
        cur_grad = grad
        for dim in range(grad.ndim):
            size = grad.shape[dim]
            if size == 1:
                continue
            param_diag_var = param_diag_vars[dim]
            num_samples = p.numel() // size
            # don't apply this reverse_cutoff thing in situations where we can't get a reasonable estimate
            # of param_cov even with stats accumulation, due to the shape of the tensor.
            reverse_cutoff = (param_reverse_cutoff if num_samples > size // 4 else 1.0e+10)
            param_diag_var = self._smooth_param_diag_var(param_diag_var,
                                                         param_pow,
                                                         param_rel_eps,
                                                         param_rel_max,
                                                         reverse_cutoff)
            param_scale = param_diag_var ** 0.5
            proj = state[f"proj_{dim}"]

            if proj.ndim == 1:
                proj *= param_scale
            M = state[f"lr_{dim}"]
            if M.ndim == 1:
                # We are treating this dimension diagonally because it is too big.
                M = M.reshape(size, *([1] * (ndim-1-dim)))
                cur_grad = cur_grad * M
            else:
                # scale the rows; dim 0 of `proj` corresponds to the
                # diagonalized space.
                proj *= param_scale.unsqueeze(-1)
                # the scalar factors on the various dimensions' projections are a little
                # arbitrary, and their product is arbitrary too.  We normalize the scalar
                # factor outside this function, see scalar_exp_avg_sq.
                cur_grad = self._multiply_on_dim(cur_grad, M, dim)
        return cur_grad


        # Reset the squared-gradient stats because we have changed the basis.
        state["exp_avg_sq"].fill_(0.0)
        state["scalar_exp_avg_sq"].fill_(0.0)

    def _multiply_by_grad_cov(self,
                              grad: Tensor,
                              state: dict,
                              forward: bool) -> Tensor:
        """
        Multiply the grad by a positive semi-definite matrix for each dimension,
        interpreted as a non-scalar learning-rate factor.
        """
        cur_grad = grad
        for dim in range(grad.ndim):
            size = grad.shape[dim]
            if size == 1:
                continue
            M = state[f"grad_cov_{dim}"]
            if M.ndim == 1:
                # We are treating this dimension diagonally because it is too big.
                M = M.reshape(size, *([1] * (ndim-1-dim)))
                # Dividing by (M.mean() + 1.0e-20) is just arbitrary normalization,
                # to avoid getting very large or very small outputs.
                cur_grad = cur_grad * (M / (M.mean() + 1.0e-20))
            else:
                # Dividing by (M.diag().mean() + 1.0e-20) is just arbitrary normalization,
                # to avoid getting very large or very small outputs.  We only need the result
                # up to an arbitrary scalar factor.
                cur_grad = self._multiply_on_dim(cur_grad, M, dim) / (M.diag().mean() + 1.0e-20)
        return cur_grad



@@ -752,94 +709,19 @@ class NeutralGradient(Optimizer):
            cur_grad = self._multiply_on_dim(cur_grad, proj, dim, forward)
        return cur_grad

    def _get_param_cov(self, p: Tensor, dim: int) -> Tensor:


    def _multiply_on_dim(self, x: Tensor, M: Tensor, dim: int) -> Tensor:
        """
        Compute a covariance matrix for one dimension of a parameter tensor,
        estimated by treating the other dimensions as a batch index.  If this would
        be rank-deficient or close to rank-deficient, make it full-rank by adding
        a random SPD matrix with what we think is a suitable scaling.

        The idea is that we want the projections to be a little random in
        the rank-deficient case, or we may get projections that lead us to
        too-dispersed estimates of parameter variance (i.e. some dimensions
        with exactly-zero parameter covariance, which leads to a too-small
        scaling for the gradient update).

        Returns:
           A non-singular covariance matrix of shape (p.shape[dim], p.shape[dim])
        """
        p = p.transpose(dim, -1)
        p = p.reshape(-1, p.shape[-1])
        (num_outer_products, size) = p.shape

        param_cov = torch.matmul(p.t(), p) / num_outer_products

        #self._randomize_lowrank_cov(param_cov, num_outer_products,
        #                            p.shape)

        return param_cov

    def _randomize_lowrank_cov(self, cov: Tensor, rank: int, shape: torch.Size,
                               min_rand: float = 0.0) -> None:
        """
        If this covariance has too-low rank (based on our knowledge of how we computed it),
        add to it a random positive-semidefinite matrix to make sure it is full-rank.
        The reason for the randomness is that when we estimate the parameter covariance
        after the co-ordinate change, we won't get a very robust estimate if the
        axes have been chosen to align too precisely with the parameter covariance
        eigenvalues.  E.g. we might get close-to-zero estimates of parameter covariance.

        Args:
          cov:   covariance matrix to randomize, of shape (size, size).  Will be modified
                 in-place (we will add to it)
          rank:  the number of outer products that `cov` is a sum over.  If
                 this is much larger than `size`, we will do nothing.
          shape: supplied for debug
        """
        # To be confident in our estimate of the covariance, we want `rank` (which
        # actually represents the number of outer products added together)
        # to be at least `required_rank`; otherwise we'll add a random component.
        size = cov.shape[0]
        required_rank = int(size * 1.2) + 1

        rand_scale = min_rand
        if rank < required_rank:
            # This epsilon is to avoid division by zero in weird situations where the
            # params are exactly zero or we sample a zero matrix; the exact value is not
            # going to affect the performance.
            eps = 1.0e-20

            # The following formula assumes that the "missing" outer products are
            # expected to be smaller than the ones that we do have, i.e. 0.2 the size.
            # We are trying to find the trace of the matrix we need to add,
            # i.e. how big it is relative to the trace of the existing matrix.
            missing_rank = (required_rank - rank)
            rand_scale = max(rand_scale, 0.2 * missing_rank / rank)
        if rand_scale > 0.0:
            R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
            R = torch.matmul(R, R.t())  # positive semidefinite random matrix
            rms_diag = (cov.diag() + 1.0e-20).sqrt()
            R = R * rms_diag * rms_diag.unsqueeze(-1)

            R_scale = R.diag().mean() + 1.0e-20
            cov_scale = cov.diag().mean() + 1.0e-20
            if random.random() < 0.02:
                print(f"required_rank={required_rank}, rank={rank}, size={size}, R_scale={R_scale}, rand_scale={rand_scale}, shape={shape}")
            cov.add_(R, alpha=rand_scale * cov_scale / R_scale)

    def _multiply_on_dim(self, x: Tensor, proj: Tensor, dim: int,
                         forward: bool = True) -> Tensor:
        """
        Matrix-multiplies x by `proj` on x's dimension numbered `dim`
        Matrix-multiplies x by symmetric matrix `M` on x's dimension numbered `dim`
        Args:
            x:    Tensor to multiply, of arbitrary shape
            proj: Matrix to multiply x by, of shape
                  (x.shape[dim], x.shape[dim]), interpreted as
                  (modified_coordinate, canonical_coordinate)
            M:    symmetric matrix to multiply x by, of shape
                  (x.shape[dim], x.shape[dim]).
        """
        return torch.matmul(x.transpose(-1, dim),
                            (proj.t() if forward else proj)).transpose(-1, dim)
        # Note: can use either M or M.t() here because M is symmetric
        return torch.matmul(x.transpose(-1, dim), M.t())
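
        # Shape sketch (not part of this commit): for x of shape (a, b, c) and dim == 1,
        # x.transpose(-1, dim) has shape (a, c, b); multiplying by the (b, b) matrix M.t()
        # acts on that last axis and gives a result of shape (a, c, b), i.e. with dims
        # `dim` and -1 swapped relative to x.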



    def _smooth_param_diag_var(self,
@@ -1024,7 +906,7 @@ class NeutralGradient(Optimizer):
        param_freqs = param_freqs.tolist()
        param_periods = param_periods.tolist()

        logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
        logging.info(f"LearnedGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
        print_info = random.random() < 0.01
        i = 0
        for p in group["params"]:
@@ -1685,10 +1567,10 @@ def _test_eve_cain():

    if iter == 0: optim = Eve(m.parameters(), lr=0.003)
    elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
    elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=10,
    elif iter == 2: optim = LearnedGradient(m.parameters(), lr=0.04, max_fullcov_size=10,
                                            estimate_period=500, stats_steps=100,
                                            lr_for_speedup=0.02)
    elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
    elif iter == 3: optim = LearnedGradient(m.parameters(), lr=0.04, max_fullcov_size=1000,
                                            estimate_period=500, stats_steps=100,
                                            lr_for_speedup=0.02)
    scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)