draft, not working, will edit locally

Daniel Povey 2022-06-30 15:35:26 +08:00
parent 0b811546f3
commit d64cb1cb48


@@ -24,150 +24,70 @@ from torch.optim import Optimizer
from icefall import diagnostics # only for testing code
import logging
class LearnedGradient(Optimizer):
"""
Implements the 'Learned Gradient' update: for each dimension of each parameter tensor
we learn a positive definite learning-rate matrix, which we multiply the derivative by;
we then scale the resulting delta/param-change before applying the update.
Args:
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
lr: The learning rate.  We will typically use a learning-rate schedule that starts
     at 0.03 and decreases over time, i.e. a much higher rate than is usual for
     other common optimizers.
meta_lr_scale: The learning rate for the learning-rate matrix; this is a scale on
     the regular learning rate.
betas: beta1, beta2 are momentum constants for regular momentum, and for the moving
     sum-squared grad (used only for scalar tensors).  Must satisfy 0 < beta1 <= beta2 < 1.
scale_speed: This scales the learning rate for purposes of learning the overall scale
of each tensor
eps: An epsilon to prevent division by zero
max_fullcov_size: We will only use the full-covariance (not-diagonal) update for
     dimensions with size <= max_fullcov_size; for those larger, we will use a
     diagonal formula.
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it >= this size)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it <= this size)
lr_grad_period: The periodicity, in steps, with which we compute the gradient w.r.t.
the learning-rate matrix. Must be more than 1, because we delay the
gradient when we do this, and doing this every time might contribute to
instability.
lr_est_period: The periodicity, in steps, with which we update the learning-rate matrix;
must be a multiple of lr_grad_period.
"""
def __init__(
self,
params,
lr=3e-02,
meta_lr_scale=1.0,
betas=(0.9, 0.98),
scale_speed=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
max_fullcov_size=1023,
param_max_rms=2.0,
lr_grad_period=4,
lr_est_period=20,
max_step_scale=5.0
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 < betas[1] < 1.0:
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 < scale_speed < 1.0:
raise ValueError("Invalid scale_speed value: {}".format(scale_speed))
if not 0.1 < param_max_rms < 1.0e+10:
raise ValueError("Invalid param_max_rms value: {}".format(param_max_rms))
if not 0.0 < param_min_rms < 0.5:
raise ValueError("Invalid param_min_rms value: {}".format(param_min_rms))
# TODO: check all args
assert lr_grad_period > 1  # see docstring: computing this grad on every step could cause instability
assert lr_est_period % lr_grad_period == 0
defaults = dict(
lr=lr,
betas=betas,
meta_lr_scale=meta_lr_scale,
scale_speed=scale_speed,
param_max_rms=param_max_rms,
param_min_rms=param_min_rms,
max_fullcov_size=max_fullcov_size,
eps=eps,
lr_grad_period=lr_grad_period,
lr_est_period=lr_est_period,
max_step_scale=max_step_scale,
)
super(LearnedGradient, self).__init__(params, defaults)
def __setstate__(self, state):
super(LearnedGradient, self).__setstate__(state)
@torch.no_grad()
@@ -185,39 +105,13 @@ class NeutralGradient(Optimizer):
for group in self.param_groups:
lr = group["lr"]
beta1, beta2, beta3 = group["betas"]
beta1, beta2 = group["betas"]
meta_lr_scale = group["meta_lr_scale"]
scale_speed = group["scale_speed"]
grad_eps = group["grad_eps"]
param_eps = group["param_eps"]
param_rel_eps = group["param_rel_eps"]
param_rel_max = group["param_rel_max"]
param_reverse_cutoff = group["param_reverse_cutoff"]
param_max_rms = group["param_max_rms"]
param_min_rms = group["param_min_rms"]
max_fullcov_size = group["max_fullcov_size"]
estimate_period = group["estimate_period"]
stats_steps = group["stats_steps"]
param_pow = group["param_pow"]
grad_pow = group["grad_pow"]
grad_min_rand = group["grad_min_rand"]
lr_for_speedup = group["lr_for_speedup"]
speedup_first_step = group["speedup_first_step"]
speedup_recalibrate_period = group["speedup_recalibrate_period"]
eps = group["eps"]
lr_est_period = group["lr_est_period"]
max_step_scale = group["max_step_scale"]
speedup_step = group["speedup_step"]
group["speedup_step"] = speedup_step + 1
recalibrate = (speedup_step >= 0 and
speedup_step % speedup_recalibrate_period == 0)
if recalibrate:
# param_prods will be an array of products, one per parameter tensor, equal to:
# E[grad . step],
# where "step" is the individual step, before momentum (i.e. computed as if beta1==0).
# We compute these in such a way as to reflect the expectation if there is no speedup
# (i.e. assuming we are doing the update on every step with no gradient accumulation).
# The idea is that this product reflects the importance of each parameter matrix,
# so that parameter matrices with larger values will need more frequent updates.
self._recalibrate_speedup(group)
for p in group["params"]:
if p.grad is None:
@@ -227,121 +121,74 @@ class NeutralGradient(Optimizer):
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"NeutralGradient optimizer does not support sparse gradients"
"LearnedGradient optimizer does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
# moving_grad is a moving average of the raw gradient.  We do the accumulation
# on the input side because it makes it easier to get grads w.r.t. the
# learning-rate matrices.
state["moving_grad"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["step"] = 0
# Parameter step (used to implement momentum).
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient value, all
# tensors have this.
state["exp_avg_sq"] = torch.zeros_like(p)
kwargs = {'device':p.device, 'dtype':p.dtype}
if p.numel() > 1:
# "scale" is just a per-tensor scalar that determines the
# rate of learning (in terms of the rms value of the change
# in parameter)
param_rms = (p**2).mean().sqrt().add_(eps)
state["scale"] = param_rms
# scalar_exp_avg_sq is the moving-average square of the gradient
# w.r.t. this scalar scale; we learn the scale on the parameters
# in a way that also emulates Eve.
state["scalar_exp_avg_sq"] = torch.zeros((), **kwargs)
for dim in range(p.ndim):
size = p.shape[dim]
if size == 1 or size == p.numel():
    continue
elif size > max_fullcov_size:
    # diagonal only...
    state[f"lr_{dim}"] = torch.ones(size, **kwargs)
else:
    state[f"lr_{dim}"] = torch.eye(size, **kwargs)
# grad_cov_{dim} is the covariance of gradients on this axis,
# treating all other axes as a batch axis.  This is needed
# only for purposes of computing the scalar factor on the
# learning rate; and, as a result of this, also contributes something
# to the gradient w.r.t. f"lr_{dim}", which is one reason
# why we allocate a variable to keep track of its moving average
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
delta = state["delta"]
moving_grad = state["moving_grad"]
step = state["step"]
delta.mul_(beta1)
if p.numel() == 1:
    # For the scalar case we just use Adam.
    self._scalar_update(p, grad, beta1, beta2, eps, lr, state)
else:
if step % lr_est_period == 0:
self._step_with_update(group, p, grad, state)
else:
self._step_no_update(group, p, grad, state)
state["step"] = step + 1
if True:
# This block learns the scale on the parameters.
@@ -363,7 +210,6 @@ class NeutralGradient(Optimizer):
scale_alpha = -lr * (scale_bias_correction2 ** 0.5) * scale_speed * (1-beta1)
enforce_rms_period = 2
scale_norm_deriv = (scale_deriv / scale_denom)
if step % enforce_rms_period == 0:
@@ -493,10 +339,160 @@ class NeutralGradient(Optimizer):
return loss
def _scalar_update(self,
                   p: Tensor,
                   grad: Tensor,
                   beta1: float,
                   beta2: float,
                   eps: float,
                   lr: float,
                   state: dict):
    assert False  # TODO: implement
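    # A plausible Adam-style body for the stub above (a sketch under the assumption
    # that state["exp_avg_sq"], state["delta"] and state["step"] are as set up in
    # step(); not the author's final code):
    #
    #   exp_avg_sq = state["exp_avg_sq"]
    #   exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
    #   bias_correction2 = 1 - beta2 ** (state["step"] + 1)
    #   denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    #   delta = state["delta"]  # already scaled by beta1 in step()
    #   delta.add_(grad / denom, alpha=-(1-beta1) * lr)
    #   p.add_(delta)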
def _step_with_update(self,
group: dict,
p: Tensor,
grad: Tensor,
state: dict):
# meta_lr is the learning rate for the learning rate matrix.
lr = group["lr"]
meta_lr = lr * group["meta_lr_scale"]
beta1 = group["betas"][0]
eps = group["eps"]
moving_g = state["moving_grad"]
moving_g.mul_(beta1)
ndim = p.ndim
# Update grad_cov
self._store_grad_stats(grad, state, beta2)
# set requires_grad to True on the learning-rate matrices state[f"lr_{dim}"],
# so that neg_loss.backward() below will populate their .grad.
for dim in range(ndim):
    if p.shape[dim] != 1:
        state[f"lr_{dim}"].requires_grad = True
with torch.enable_grad():
moving_d = self._multiply_by_lr(moving_g, state)
# moving_d_rms is an rms value of moving_d computed via inner products with
# the grad, which is a more basis-independent way of computing the rms value.
# you have to view this as the rms times an arbitrary multiplicative factor.
moving_d_rms = (self._multiply_by_grad_cov(moving_d, state) * moving_d).mean().sqrt()
# moving_d_norm is the parameter "delta" moving_d, renormalized via an inner
# product with the gradient covariances (for basis independence), but still with
# an arbitrary multiplicative constant. We don't care about this constant
# because we are going to normalize it out anyway.
moving_d_norm = moving_d / moving_d_rms
# View the parameter p as [something] - alpha * moving_d_norm, where alpha contains
# the learning rate and other factors; we don't care about the value of alpha
# because we're going to divide by the gradient norm anyway.
# The (pseudo-) loss function would be (param * grad).sum(), which ignoring
# [something], is -alpha * (moving_d_norm * grad).sum(), and ignoring
# alpha that is -(moving_d_norm * grad).sum(), so we call this neg_loss.
neg_loss = (grad * moving_d_norm).sum()
# Now there will be grads in state[f"lr_{dim}"] that are the negative of the
# loss-function grads, i.e. we want to move in the direction of these grads.
neg_loss.backward()
for dim in range(ndim):
if p.shape[dim] != 1:
self._update_lr(state[f"lr_{dim}"], meta_lr)
# moving_d contains a moving average of gradients with coefficients (1-beta) * (beta, beta^2, beta^3, ...).
# Suppose the rms value of each individual gradient were r.  Assuming they are independent,
# the variance of the elements of moving_d would be r^2 * (1-beta)^2 * (beta^2 + beta^4 + ...)
# (the coefficient series starts with beta rather than 1 because, on "update iters",
# we delay the current grad term `grad`, which would normally appear with coefficient 1,
# until the next iter), which equals r^2 * (1-beta)^2 * beta^2 / (1-beta^2).  So the
# measured rms of moving_d would be
#   moving_d_rms = individual_d_rms * (1-beta) * beta / sqrt(1-beta^2), i.e.
#   individual_d_rms = moving_d_rms * sqrt(1-beta^2) / ((1-beta) * beta).
# ... this would be the root-mean-square value of the individual deltas, if we were to
# process them individually (i.e. if they were independent), and this is what we want to
# normalize by for greater independence of the beta value, i.e. so that beta only affects
# the momentum and not the overall speed.
moving_d_rms = (moving_d**2).mean().sqrt() + eps
# e.g. if beta1 == 0.98, this formula gives a factor of sqrt(1 - 0.98^2) / (1 - 0.98) ≈ 9.95.
d_rms = moving_d_rms * (((1-beta1*beta1)**0.5) / (1-beta1))
alpha = -lr / d_rms
p.add_(moving_d, alpha=alpha)
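# Illustrative magnitudes (an assumption for intuition, not from the original):
# with lr = 0.03 and beta1 = 0.98, d_rms ≈ 9.95 * moving_d_rms, so the parameter
# change applied this step has rms ≈ lr * moving_d_rms / d_rms ≈ 0.003.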
def _update_lr(self,
lr_mat: Tensor,
meta_lr: float) -> None:
"""
Do one step of update for learning rate matrix 'lr_mat', which we assume has already
has its `grad` property populated with the grad of the negative loss (our special
loss that we use to train the learning rate matrix). Also delete lr_mat.grad
Args:
lr_mat: The symmetric positive-semidefinite (PSD) learning-rate matrix.
Will have its .grad populated with the derivative of a negated
loss function (i.e. to be maximized), that we use to optimize
this learning-rate matrix. lr_mat.grad will be deleted and
lr_mat will be updated by one step.
meta_lr: The speed with which we learn lr_mat.
"""
grad = lr_mat.grad
del lr_mat.grad
lr_mat.requires_grad = False
# OK. suppose lr_mat = M, which is symmetric positive definite.
# Decompose: M = A N A^T
# where N is numerically equal to M (but a separate variable), and A is a variable numerically
# equal to I. To make it easier to keep the result positive definite, we learn A rather than N.
# We'll use a convention where A_grad has the same orientation as A, i.e. A_grad_ij equals the grad
# of loss w.r.t A_ij.
# The grad w.r.t. A equals (N M_grad)^T + (M_grad N) = 2 (M_grad N).  We don't care about
# scalar factors at this point (we'll normalize them out), so ignore the factor of 2 and say:
# A_grad = M_grad N.
A_grad = torch.matmul(grad, lr_mat)
# Normalize the size of `A_grad` (there are arbitrary scalar factors in this loss function),
# and include meta_lr
A_grad = A_grad * (meta_lr / ((A_grad ** 2).mean().sqrt() + 1.0e-20))
# So the updated A is just I plus the normalized A_grad (which already includes meta_lr).
size = lr_mat.shape[0]
A = torch.eye(size, device=lr_mat.device, dtype=lr_mat.dtype) + A_grad
# the result is positive semidefinite if the initial LR_mat was positive semidefinite,
# because it is of the form X X^T where X = A lr_mat^{0.5}
torch.matmul(A, torch.matmul(lr_mat, A.t()), out=lr_mat)
# Normalize the scale of lr_mat: we always ensure lr_mat.diag().mean() is 1.0
lr_mat *= 1.0 / lr_mat.diag().mean()
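# Sanity check of the positive-semidefinite claim above (illustrative only):
# if lr_mat = L L^T then A (L L^T) A^T = (A L)(A L)^T, which is PSD for any A, e.g.:
#
#   M = torch.eye(4)
#   A = torch.eye(4) + 0.1 * torch.randn(4, 4)
#   M2 = A @ M @ A.t()
#   assert (torch.linalg.eigvalsh(M2) >= -1.0e-06).all()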
def _step_no_update(self,
group: dict,
p: Tensor,
grad: Tensor,
state: dict):
pass
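    # Not implemented in this draft.  A plausible sketch, consistent with the tail of
    # _step_with_update (an assumption, not the author's final code): accumulate the
    # gradient and apply the existing learning-rate matrices without autograd:
    #
    #   beta1 = group["betas"][0]
    #   lr = group["lr"]
    #   eps = group["eps"]
    #   moving_g = state["moving_grad"]
    #   moving_g.mul_(beta1).add_(grad)
    #   moving_d = self._multiply_by_lr(moving_g, state)
    #   d_rms = ((moving_d**2).mean().sqrt() + eps) * (((1-beta1*beta1)**0.5) / (1-beta1))
    #   p.add_(moving_d, alpha=-lr / d_rms)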
def _store_grad_stats(self,
grad: Tensor,
state: dict,
beta2: float) -> None:
"""
Accumulate some stats for each dimension of the tensor, about the covariance of gradients.
The stats are just summed, not decayed; we will re-estimate the projections periodically,
@@ -506,23 +502,18 @@ class NeutralGradient(Optimizer):
kwargs = {'device': grad.device, 'dtype': grad.dtype}
for dim in range(ndim):
size = grad.shape[dim]
if size == 1:
continue
name = f"grad_cov_{dim}"
if name not in state:
    state[name] = torch.zeros(size, size, **kwargs)
grad_cov = state[name]
g = grad.transpose(dim, -1)
g = g.reshape(-1, g.shape[-1])
# We don't care about the scalar factor, we will normalize when we
# use it.
grad_cov.mul_(beta2).add_(torch.matmul(g.t(), g))
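# Illustrative shapes (an assumption, not from the original): for grad of shape
# (16, 32, 64) and dim == 1, g = grad.transpose(1, -1) has shape (16, 64, 32),
# g.reshape(-1, 32) stacks 16*64 = 1024 row-vectors, and g.t() @ g is (32, 32).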
def _estimate_projections(self,
p: Tensor,
@@ -623,90 +614,56 @@ class NeutralGradient(Optimizer):
param_reverse_cutoff)
def _multiply_by_lr(self,
                    grad: Tensor,
                    state: dict) -> Tensor:
"""
This is called from _estimate_projections() after suitably "diagonalizing" bases
have already been estimated and written in state[f"proj_{dim}"] for the
non-trivial dimensions of `p`. This function computes a factored estimate
of the parmeter variance, and uses this to apply scales to the rows of the
projections.
See the documentation of _estimate_projections() for documentation of the
args.
Multiply the grad by a positive semi-definite matrix for each dimension,
interpreted as a non-scalar learning-rate factor.
"""
cur_grad = grad
for dim in range(grad.ndim):
    size = grad.shape[dim]
    if size == 1:
        continue
    M = state[f"lr_{dim}"]
    if M.ndim == 1:
        # We are treating this dimension diagonally because it is too big.
        M = M.reshape(size, *([1] * (grad.ndim - 1 - dim)))
        cur_grad = cur_grad * M
    else:
        cur_grad = self._multiply_on_dim(cur_grad, M, dim)
return cur_grad
def _multiply_by_grad_cov(self,
                          grad: Tensor,
                          state: dict) -> Tensor:
"""
Multiply the grad by a positive semi-definite matrix for each dimension,
interpreted as a non-scalar learning-rate factor.
"""
cur_grad = grad
for dim in range(grad.ndim):
size = grad.shape[dim]
if size == 1:
continue
M = state[f"grad_cov_{dim}"]
if M.ndim == 1:
# We are treating this dimension diagonally because it is too big.
M = M.reshape(size, *([1] * (grad.ndim - 1 - dim)))
# Dividing by (M.mean() + 1.0e-20) is just arbitrary normalization,
# to avoid getting very large or very small outputs.
cur_grad = cur_grad * (M / (M.mean() + 1.0e-20))
else:
# Dividing by (M.diag().mean() + 1.0e-20) is just arbitrary normalization,
# to avoid getting very large or very small outputs. We only need the result
# up to an arbitrary scalar factor.
cur_grad = self._multiply_on_dim(cur_grad, M, dim) / (M.diag().mean() + 1.0e-20)
return cur_grad
@@ -752,94 +709,19 @@ class NeutralGradient(Optimizer):
cur_grad = self._multiply_on_dim(cur_grad, proj, dim, forward)
return cur_grad
def _multiply_on_dim(self, x: Tensor, M: Tensor, dim: int) -> Tensor:
"""
Compute a covariance matrix for one dimension of a parameter tensor,
estimated by treating the other dimensions as a batch index. If this would
be rank-deficient or close to rank-deficient, make it full-rank by adding
a random SPD matrix with what we think is a suitable scaling.
The idea is that we want the projections to be a little random in
the rank-deficient case, or we may get projections that lead us to
too-dispersed estimates of parameter variance (i.e. some dimensions
with exactly-zero parameter covariance, which leads to a too-small
scaling for the gradient update).
Returns:
A non-singular covariance matrix of shape (p.shape[dim], p.shape[dim])
"""
p = p.transpose(dim, -1)
p = p.reshape(-1, p.shape[-1])
(num_outer_products, size) = p.shape
param_cov = torch.matmul(p.t(), p) / num_outer_products
#self._randomize_lowrank_cov(param_cov, num_outer_products,
# p.shape)
return param_cov
def _randomize_lowrank_cov(self, cov: Tensor, rank: int, shape: torch.Size,
min_rand: float = 0.0) -> None:
"""
If this covariance has too-low rank (based on our knowledge of how we computed it),
add to it a random positive-semidefinite matrix to make sure it is full-rank.
The reason for the randomness is that when we estimate the parameter covariance
after the co-ordinate change, we won't get a very robust estimate if the
axes have been chosen to align too precisely with the parameter covariance
eigenvalues. E.g. we might get close-to-zero estimates of parameter covariance.
Args:
cov: covariance matrix to randomize, of shape (size, size). Will be modified
in-place (we will add to it)
rank: the number of outer products that `cov` is a sum over. If
this is much larger than `size`, we will do nothing.
shape: supplied for debug
"""
# To be confident in our estimate of the covariance, we want `rank` (which
# actually represents the number of outer products added together)
# to be at least `required_rank`; otherwise we'll add a random component.
size = cov.shape[0]
required_rank = int(size * 1.2) + 1
rand_scale = min_rand
if rank < required_rank:
# This epsilon is to avoid division by zero in weird situations where the
# params are exactly zero or we sample a zero matrix; the exact value is not
# going to affect the performance.
eps = 1.0e-20
# The following formula assumes that the "missing" outer products are
# expected to be smaller than the ones that we do have, i.e. 0.2 the size.
# We are trying to find the trace of the matrix we need to add,
# i.e. how big it is relative to the trace of the existing matrix.
missing_rank = (required_rank - rank)
rand_scale = max(rand_scale, 0.2 * missing_rank / rank)
if rand_scale > 0.0:
R = torch.randn(size, size, device=cov.device, dtype=cov.dtype)
R = torch.matmul(R, R.t()) # positive semidefinite random matrix
rms_diag = (cov.diag() + 1.0e-20).sqrt()
R = R * rms_diag * rms_diag.unsqueeze(-1)
R_scale = R.diag().mean() + 1.0e-20
cov_scale = cov.diag().mean() + 1.0e-20
if random.random() < 0.02:
print(f"required_rank={required_rank}, rank={rank}, size={size}, R_scale={R_scale}, rand_scale={rand_scale}, shape={shape}")
cov.add_(R, alpha=rand_scale * cov_scale / R_scale)
def _multiply_on_dim(self, x: Tensor, proj: Tensor, dim: int,
forward: bool = True) -> Tensor:
"""
Matrix-multiplies x by `proj` on x's dimension numbered `dim`
Matrix-multiplies x by symmetric matrix `M` on x's dimension numbered `dim`
Args:
x: Tensor to multiply, of arbitrary shape
proj: M matrix to multiply x by, of shape
(x.shape[dim], x.shape[dim]), interpreted as
(modified_coordinate, canonical_coordinate)
M: symmetric matrix to multiply x by, of shape
(x.shape[dim], x.shape[dim]).
"""
return torch.matmul(x.transpose(-1, dim),
(proj.t() if forward else proj)).transpose(-1, dim)
# Note: can use either M or M.t() here because M is symmetric
return torch.matmul(x.transpose(-1, dim), M.t())
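# Illustrative shapes (an assumption, not from the original): with x of shape
# (16, 32, 64) and dim == 1, M must be (32, 32); x.transpose(-1, 1) is (16, 64, 32),
# the matmul keeps that shape, and the final .transpose(-1, 1) restores (16, 32, 64).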
def _smooth_param_diag_var(self,
@@ -1024,7 +906,7 @@ class NeutralGradient(Optimizer):
param_freqs = param_freqs.tolist()
param_periods = param_periods.tolist()
logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
logging.info(f"LearnedGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
print_info = random.random() < 0.01
i = 0
for p in group["params"]:
@@ -1685,10 +1567,10 @@ def _test_eve_cain():
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
elif iter == 2: optim = LearnedGradient(m.parameters(), lr=0.04, max_fullcov_size=10)
elif iter == 3: optim = LearnedGradient(m.parameters(), lr=0.04, max_fullcov_size=1000)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)