Implement new version of learning method. Does a more complete diagonalization of grads than the previous methods.

This commit is contained in:
Daniel Povey 2022-07-09 10:02:17 +08:00
parent a9edecd32c
commit 6810849058


@ -24,10 +24,10 @@ from torch.optim import Optimizer
from icefall import diagnostics # only for testing code
import logging
class LearnedGradient(Optimizer):
class PrAdam(Optimizer):
"""
Implements the 'Learned Gradient' update. It's a special factorization of the
parameter matrix.
Implements 'Projected Adam', a variant of Adam with projections for each
dim of each tensor.
Args:
@ -35,63 +35,70 @@ class LearnedGradient(Optimizer):
lr: The learning rate. We will typically use a learning rate schedule that starts
at 0.03 and decreases over time, i.e. much higher than other common
optimizers.
betas: beta1,beta2 are momentum constants for regular momentum, and for the moving
sum-sq grad (the latter only for scalar tensors). Must satisfy 0 < beta1 <= beta2 < 1.
size_lr_scale: A scaling factor on the learning rate, that we use to update the
scale of each parameter tensor. If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
is the scaling factor on the learning rate of p_scale.
meta_lr_scale: The learning rate for the learning-rate matrix, this is a scale on
the regular learning rate.
betas: beta1,beta2 are momentum constants for regular momentum, and for the moving
sum-sq grad (the latter only for scalar tensors). Must satisfy 0 < beta1 <= beta2 < 1.
param_pow: Power on the parameter covariance matrix; 1.0 means learning proportional
to parameter rms (1.0 will be too much; it should be between 0 and 1).
param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
assumed rank of param covariance [==product of sizes on the other
tensor dims] approaches 0.
param_rms_smooth1: Smoothing proportion for parameter matrix, if the assumed rank of
the param covariance equals the dimension of the covariance matrix.
param_rms_smooth{0,1} together determine the smoothing proportions for
the in-between conditions.
param_cov_freshness: Constant that determines how "recent" the parameter covariance
matrix is. 1.0 corresponds to flat average over time; 2.0
would mean weighting with a quadratic function, etc.
eps: An epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it >= this size)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it <= this size)
lr_mat_min: Minimum singular value of learning-rate matrices Q.
lr_mat_max: Maximum singular value of learning-rate matrices Q.
lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
diagonalize_period: How often we re-diagonalize the gradient covariance; this
gets multiplied by lr_est_period, i.e. we diagonalize less frequently
than we update the learning-rate matrices.
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time.
lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
"""
def __init__(
self,
params,
lr=3e-02,
size_lr_scale=0.1,
meta_lr_scale=0.0,
betas=(0.9, 0.98),
size_lr_scale=0.1,
param_pow=0.75,
param_rms_smooth0=0.75,
param_rms_smooth1=0.25,
param_cov_freshness=1.0,
eps=1.0e-08,
size_update_period=1,
param_min_rms=1.0e-05,
param_max_rms=2.0,
lr_mat_min=0.01,
lr_mat_max=10.0,
lr_est_period=2,
diagonalize_period=4,
size_update_period=4,
lr_update_period=20,
):
defaults = dict(
lr=lr,
size_lr_scale=size_lr_scale,
meta_lr_scale=meta_lr_scale,
param_pow=param_pow,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
param_cov_freshness=param_cov_freshness,
betas=betas,
eps=eps,
size_update_period=size_update_period,
param_min_rms=param_min_rms,
param_max_rms=param_max_rms,
lr_mat_min=lr_mat_min,
lr_mat_max=lr_mat_max,
lr_est_period=lr_est_period,
diagonalize_period=diagonalize_period,
size_update_period=size_update_period,
lr_update_period=lr_update_period,
)
super(LearnedGradient, self).__init__(params, defaults)
super(PrAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(LearnedGradient, self).__setstate__(state)
super(PrAdam, self).__setstate__(state)
@torch.no_grad()
@ -111,14 +118,11 @@ class LearnedGradient(Optimizer):
lr = group["lr"]
size_lr = lr * group["size_lr_scale"]
beta1, beta2 = group["betas"]
lr_mat_min = group["lr_mat_min"]
lr_mat_max = group["lr_mat_max"]
eps = group["eps"]
size_update_period = group["size_update_period"]
param_min_rms = group["param_min_rms"]
param_max_rms = group["param_max_rms"]
lr_est_period = group["lr_est_period"]
diagonalize_period = group["diagonalize_period"]
lr_update_period = group["lr_update_period"]
for p in group["params"]:
if p.grad is None:
@ -128,7 +132,7 @@ class LearnedGradient(Optimizer):
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"LearnedGradient optimizer does not support sparse gradients"
"PrAdam optimizer does not support sparse gradients"
)
state = self.state[p]
@ -177,29 +181,14 @@ class LearnedGradient(Optimizer):
continue
# Q_{dim}, which we will call q in this comment, will be the learning-rate matrix for this
# dimension, times an orthogonal matrix that we'll periodically
# re-estimate when we "rediagonalize" (this diagonalizes
# the gradient covariance).
# If our parameter matrix M is of shape (..., size),
# we'll view M as being M == torch.matmul(N, Q)
# so Q can be interpreted as having shape
# (size, size), indexed [diagonalized_coordinate, canonical_coordinate]
# by "diagonalize" we mean that we diagonalize the gradient covariance.
# Q_{dim} is the learning-rate matrix for this
# dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate];
# it is used twice in the update step, once transposed and once without transpose.
state[f"Q_{dim}"] = torch.eye(size, **kwargs)
# proj_grad_{dim} is the gradient w.r.t. `proj`,
# which is a matrix we introduce to the right of Q, i.e.
# instead of M == torch.matmul(N, Q) we view it as
# M == torch.matmul(N, torch.matmul(Q, proj))
# or equivalently M == torch.matmul(M_underlying, proj)
# where M_underlying is numerically equal to M, and proj is numerically
# equal to I, but they are viewed as separate variables.
#
# proj_grad_{dim} will be set to the sum of
# torch.matmul(M^T, M_grad) over a number of steps (`lr_est_period` steps),
# and will be zeroed after we call self._update_lrs().
state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
# param_cov_{dim} is the averaged-over-time covariance of the parameters on this dimension, treating
# all other dims as a batch axis.
state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)
# grad_cov_{dim} is the covariance of gradients on this axis (without
# any co-ordinate changes), treating all other axes as a batch axis.
@ -213,14 +202,6 @@ class LearnedGradient(Optimizer):
# instead of just using a temporary and smoothing the scalar factor.
state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)
# proj_grad_var_{dim} is the variance of gradients of the thing we're
# going to multiply proj_{dim} by (between it and the "underlying
# parameter matrix"), aggregated over each of its 2 dimensions
# and decayed over time with beta1 each time we do a step on proj_{dim},
# e.g. every lr_est_period steps. It's a factored representation
# of something like exp_avg_sq, used for estimating proj_{dim}; we avoid
# storing an exp_avg_sq for proj_{dim}, to save GPU memory.
state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs)
step = state["step"]
@ -250,10 +231,9 @@ class LearnedGradient(Optimizer):
# parameter rms. Updates delta.
self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
else:
if step % lr_est_period == 0 and step > 0:
if step % lr_update_period == 0 and step > 0:
self._accum_param_covs(group, p, state)
self._update_lrs(group, p, state)
if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
self._diagonalize_lrs(group, p, state)
self._zero_exp_avg_sq(state)
self._step(group, p, grad, state)
p.add_(delta)
@ -326,31 +306,19 @@ class LearnedGradient(Optimizer):
# the factor of (1-beta1) relates to momentum.
delta.add_(p, alpha=(1-beta1) * scale_step)
def _update_lrs(self,
group: dict,
p: Tensor,
state: dict) -> None:
def _accum_param_covs(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Updates the learning-rate matrices state["Q_{dim}"].
Args:
group: dict to look up configuration values
p: parameter matrix that we are updating. The learning rate matrices
are actually factors of p, so p itself will change when we change
them.
state: state dict for the current parameter
Updates state[f"param_cov_{dim}"] for each dim of tensor p.
"""
# meta_lr is the learning rate for the learning rate matrix. Including the factor
# (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
# to the lr_est_period (close to convergence).
meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
beta1 = group["betas"][0]
eps = group["eps"]
delta = state["delta"]
ndim = p.ndim
for dim in range(ndim):
lr_update_period = group["lr_update_period"]
step = state["step"]
this_weight = (group["param_cov_freshness"] * lr_update_period /
(step + lr_update_period))
for dim in range(p.ndim):
size = p.shape[dim]
if size == 1:
continue
@ -359,153 +327,14 @@ class LearnedGradient(Optimizer):
# all other dimensions of the tensor.
M_full = p.transpose(dim, -1)
M = M_full.reshape(-1, size)
this_param_cov = torch.matmul(M.t(), M)
# normalize scale of this param_cov, in case the parameter scale
# changes significantly during training, which would cause some
# parts of the training timeline to be more highly weighted.
this_param_cov /= (this_param_cov.diag().mean() + eps)
state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
alpha=this_weight)
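# A quick sketch (not part of the diff) of the weighting above: with
# param_cov_freshness == 1.0 and an update every lr_update_period steps,
# this_weight == 1/(k+1) at the k-th update, so the recurrence reduces to a
# flat average of the samples (plus the zero initialization).
T = 20                                     # lr_update_period
samples = [torch.randn(4, 4) for _ in range(50)]
avg = torch.zeros(4, 4)
for k, sample in enumerate(samples, start=1):
    w = 1.0 * T / (k * T + T)              # this_weight at step == k * T
    avg = (1 - w) * avg + w * sample
flat = torch.stack(samples).sum(dim=0) / (len(samples) + 1)
assert torch.allclose(avg, flat, atol=1e-4)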
# proj_grad is a summed gradient, of shape (size, size),
# indexed [old_param_idx][new_param_idx] (the point is,
# old_param_idx corresponds to the index of the current M, new_param_idx
# corresponds to the index of the M after a small change).
#
# Suppose M is our parameter matrix p reshaped to (-1, size) with
# the -1 being all other dims of the tensor treated as a batch dim
# and "size" is the size of whatever dimension we are working on
# right now. And suppose M_grad is the loss derivative w.r.t. M
# (for compatibility with things like Torch, we use a convention
# where all loss derivatives/grads have the same shape as the things they
# are derivative with respect to).
#
# proj_grad is the accumulated value, on the last few steps, of M M_grad^T,
# which will be of shape (size, size). For purposes of our update, if we view
# the parameter matrix M as (M_underlying proj) with
# proj being numerically equal to I but treated as a variable, i.e. as
# M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
# it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
# where proj_grad == (M_grad^T M_underlying)^T == M_underlying^T M_grad == M^T M_grad
#
# Now, proj_grad is convenient to compute, being independent of Q,
# but it's not in a very normalized
# space; we want to normalize it before doing gradient descent on it.
# We're going to view M as
# M == N Q,
# where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
# [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
# It's likely better numerically to do gradient descent on a `proj2` matrix
# that's between N and Q,
# i.e.:
# M == N proj2 Q
# We want to obtain the derivative w.r.t. proj2.
# So the linearized pseudo-loss is
# tr(M_grad^T M) == tr(M_grad^T N proj2 Q) == tr(Q M_grad^T N proj2),
# which we can view as tr(proj2_grad^T proj2), where
# proj2_grad == (Q M_grad^T N)^T == N^T M_grad Q^T == Q^{-T} M^T M_grad Q^T
# proj2_grad = Q^{-T} proj_grad Q^T (eq.1)
#
# After computing proj2_grad we compute/update a factorized representation of
# its variance, factorized over its 2 axes, and do an Adam-type update
# for it, except without momentum at this level (momentum will be applied
# later). This gives us proj2_delta, according to the
# equation:
# proj2 = I + proj2_delta
# (again: proj2_delta is derived from proj2_grad using an Adam-style update).
#
# We now have to update the "learning-rate matrix" Q using proj2, as in
# Q := proj2 Q
# Q := (proj2_delta + I) Q, or:
# Q += proj2_delta Q (eq.2)
#
# We also need to update M itself; this is easiest done by computing proj_delta
# (which is the difference of proj from I). The relationship between proj and proj2
# is given by equating
# M == N proj2 Q == N Q proj,
# so proj2 Q == Q proj
# proj == Q^{-1} proj2 Q
# and subtracting out the constant "I" part,
# proj_delta == Q^{-1} proj2_delta Q (eq.3)
#
# and then because we view the parameter matrix as
# "M proj",
# with proj being numerically equal to I but learned here,
# it becomes M (I + proj_delta) = M + M proj_delta.
# So the update to M is:
# M += M proj_delta
# or equivalently:
# M_delta = M proj_delta (eq.5)
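# A quick numerical check of the change of variables above (a sketch assuming
# a torch version with torch.linalg): with proj_delta = Q^{-1} proj2_delta Q
# (eq.3), the update M_delta = M proj_delta (eq.5) equals N proj2_delta Q,
# i.e. the same change expressed in the proj2 coordinates.
torch.manual_seed(0)
_size = 6
_M = torch.randn(10, _size)
_Q = torch.eye(_size) + 0.1 * torch.randn(_size, _size)   # invertible lr matrix
_N = _M @ torch.linalg.inv(_Q)                            # N is implicit in the optimizer
_proj2_delta = 0.01 * torch.randn(_size, _size)
_proj_delta = torch.linalg.solve(_Q, _proj2_delta @ _Q)   # (eq.3)
assert torch.allclose(_M @ _proj_delta, _N @ _proj2_delta @ _Q, atol=1e-4)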
proj_grad = state[f"proj_grad_{dim}"]
Q = state[f"Q_{dim}"]
# delta_scale < 1 will make it update the learning rates faster than it otherwise would,
# as we'll reach equilibrium with M less rapidly.
delta_scale=0.5
if True:
# Simpler update, do it to the right of Q.
# symmetrize, otherwise there is a bad interaction with the regular update.
proj_grad[:] = proj_grad + proj_grad.t()
proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
denom = proj_grad_var.sqrt()
proj_delta = proj_grad / denom
Q_delta = torch.matmul(Q, proj_delta)
M_delta = torch.matmul(M, proj_delta)
M_delta_full = M_delta.reshape(M_full.shape)
this_delta = M_delta_full.transpose(dim, -1)
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
# there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr)
proj_grad.zero_()
continue
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
# torch.linalg.solve(A, B) returns A^{-1} B.
try:
# not all versions of Torch have this.
Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
except:
# torch.solve has args the other way round and also returns LU.
Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t())
# (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
# symmetrize, otherwise there is a bad interaction with the regular
# update in self._step().
proj2_grad = proj2_grad + proj2_grad.t()
# for now just estimate proj2_grad_var from proj2_grad; later, we
# may try aggregating it across iterations. This is like exp_avg_sq
# in Adam.
proj2_grad_var = self._get_matrix_var(proj2_grad, group["eps"])
denom = proj2_grad_var.sqrt()
# we'll take into account the factor of -meta_lr * (1-beta1) later on;
# actually proj2_delta is the negative of a parameter change.
proj2_delta = proj2_grad / denom
# See (eq.2), Q += proj2_delta Q
Q_delta = torch.matmul(proj2_delta, Q)
# See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
try:
proj_delta = torch.linalg.solve(Q, Q_delta)
except:
# torch.solve has args the other way round and also returns LU.
proj_delta, _ = torch.solve(Q_delta, Q)
M_delta = torch.matmul(M, proj_delta) # (eq.5), M_delta = M proj_delta
M_delta_full = M_delta.reshape(M_full.shape)
this_delta = M_delta_full.transpose(dim, -1)
# we'll be applying momentum on delta, so multiply by (1-beta1) to
# get the total magnitude of the change right. The 2 changes below
# are the final changes, the only 2 we make in this loop that have
# side effects.
delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
# there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr)
proj_grad.zero_()
def _zero_exp_avg_sq(self,
state: dict) -> None:
@ -516,25 +345,12 @@ class LearnedGradient(Optimizer):
state["scalar_exp_avg_sq"].zero_()
state["zero_step"] = state["step"]
def _diagonalize_lrs(self,
group: dict,
p: Tensor,
state: dict) -> None:
def _update_lrs(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Diagonalizes the learning-rate matrices state["Q_{dim}"], and also
applies the min- and max- singular value constraints. This Q
is applied to the parameter matrix M such that we interpret
M == N Q.
By "diagonalize", we mean to left-miltiply Q by an orthogonal matrix,
Q := U Q,
so that the covariance of gradients w.r.t. N, measured
over the 2nd index of N while treating its 1st index as a batch index,
is diagonal.
Before doing this, we also limit the singular values of Q so that the
min,max singular values are lr_mat_min,lr_mat_max, then renormalize so the
mean is 1.0 and limit again.
Computes learning-rate matrices Q for each dim of this tensor.
Args:
group: dict to look up configuration values
p: parameter matrix that we are updating. The learning rate matrices
@ -542,158 +358,87 @@ class LearnedGradient(Optimizer):
them.
state: state dict for the current parameter
"""
# meta_lr is the learning rate for the learning rate matrix. Including the factor
# (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
# to the lr_est_period (close to convergence).
lr_mat_min = group["lr_mat_min"]
lr_mat_max = group["lr_mat_max"]
ndim = p.ndim
numel = p.numel()
for dim in range(ndim):
size = p.shape[dim]
if size == 1:
continue
param_cov = state[f"param_cov_{dim}"]
U, S, V = _svd(param_cov)
rank = numel // size
rms = self._smooth_param_rms(group, S.sqrt(), rank)
if random.random() < 0.05:
logging.info(f"Shape={tuple(p.shape)}, dim={dim}, size={size}, rms={rms[::10]}")
Q = state[f"Q_{dim}"]
Q_orig = Q.clone() # TEMP
for i in range(4):
try:
self._diagonalize_lr(group, p, dim, state)
break # break from the loop over i: nothing was raised, so it succeeded.
except:
logging.warning(f"self._diagonalize_lr raised: i={i}, Q.sum()={Q.sum()}: likely "
f"error in SVD. Will try again after random change.")
# Likely torch SVD failed, its implementation for GPU has bugs for some torch versions,
# e.g. 1.7.1. Left-multiply Q by an arbitrary orthogonal matrix, and we'll try again.
# (Q is going to be diagonalized on that axis anyway, so this shouldn't meaningfully
# affect the output).
U, S, V = torch.randn_like(Q_orig).svd()
# U is a random orthogonal matrix.
Q[:] = torch.matmul(U, Q_orig)
Q[:] = (U * rms).t()
if True:
# This block does the actual diagonalization.
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
# the -1 represents all other tensor dims treated as a batch dimension.
# M_grad is the same shape as M. We could write a pseudo-loss as
# loss = tr(M_grad^T M)
# Because we can decompose M as M == N Q, we can write:
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
# so we can write this as tr(N_grad^T N),
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize), as:
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that is larger if the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = _svd(N_grad_cov)
if random.random() < 0.1:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N).
# This can be written as tr(U U^T N_grad^T N), since U U^T == I,
#
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
# as tr(hat_N_grad^T hat_N), where:
# hat_N_grad = N_grad U
# hat_N = N U
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.t(), Q)
def _diagonalize_lr(self,
group: dict,
p: Tensor,
dim: int,
state: dict) -> None:
"""
This piece of code was broken out of _diagonalize_lrs so that we can more easily
recover from errors in torch SVD, which can be quite common in some torch versions;
e.g. in torch.1.7.1+cu101, CUDA SVD generates NaNs for some very reasonably-sized
inputs; for example, a 200 by 200 "bad case" whose min and max elements are
-0.2766 and 0.3444 respectively generates NaNs in its U and V factors.
Having closely-spaced singular values seems to present a problem.
"""
lr_mat_min = group["lr_mat_min"]
lr_mat_max = group["lr_mat_max"]
Q = state[f"Q_{dim}"]
if True:
# This block limits the singular values of Q.
U,S,V = Q.svd()
S_new = S.clamp(lr_mat_min, lr_mat_max)
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
# Reconstruct Q with the modified S.
Q[:] = torch.matmul(U * S_new, V.t())
if random.random() < 0.03:
subsample = max(1, S.numel() // 20)
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}")
if True:
# This block does the actual diagonalization.
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
# the -1 represents all other tensor dims treated as a batch dimension.
# M_grad is the same shape as M. We could write a pseudo-loss as
# loss = tr(M_grad^T M)
# Because we can decompose M as M == N Q, we can write:
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
# so we can write this as tr(N_grad^T N),
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize), as:
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that is larger if the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = N_grad_cov.svd()
if random.random() < 0.1:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N).
# This can be written as tr(U U^T N_grad^T N), since U U^T == I,
#
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
# as tr(hat_N_grad^T hat_N), where:
# hat_N_grad = N_grad U
# hat_N = N U
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.t(), Q)
s = Q[:].sum()
if not s == s: # if a NaN was generated:
raise RuntimeError('Likely SVD failure')
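# A quick numerical check of (eq.10) above (a sketch assuming torch.linalg is
# available): rotating the implicit N by an orthogonal U while setting
# hat_Q := U^T Q leaves the product M = N Q unchanged.
torch.manual_seed(0)
_size = 8
_M = torch.randn(20, _size)
_Q = torch.eye(_size) + 0.1 * torch.randn(_size, _size)
_N = _M @ torch.linalg.inv(_Q)
_U, _, _ = torch.linalg.svd(torch.randn(_size, _size))    # random orthogonal U
assert torch.allclose((_N @ _U) @ (_U.t() @ _Q), _M, atol=1e-4)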
def _get_matrix_var(self,
x: Tensor,
eps: float) -> Tensor:
"""
Estimate a variance of the elements of x, which must be a tensor with exactly
2 dimensions, neither of which can have size == 1. This is done using
a structured representation, so the variance is a product over the 2
dimensions (otherwise, it wouldn't be possible to estimate a variance from
a single sample). Eventually we may do this via aggregation across
time as well; this approach seems to give a high-variance estimate even
for largish dimensions, e.g. if x.shape == (100,100) the 95th percentile
is from about 0.6 to 1.6.
Args:
x: tensor to estimate the variance of the elements of. Requires x.ndim == 2,
and neither size may be 1.
eps: Epsilon to avoid zero variance; epsilon is dimensionally an rms value.
"""
assert x.ndim == 2 and 1 not in list(x.shape) # TODO: remove?
var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
# estimate var0 one more time.
var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
return var1 * var0
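# A standalone restatement of the factored estimate above, for illustration
# (the name is hypothetical): on unit-variance noise the product of row and
# column factors stays close to 1.
def _factored_var_sketch(x: Tensor, eps: float = 1.0e-08) -> Tensor:
    # same three half-iterations as _get_matrix_var above
    var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
    var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
    var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
    return var1 * var0                     # shape == x.shape, via broadcasting
_v = _factored_var_sketch(torch.randn(100, 100))
assert 0.5 < _v.mean().item() < 2.0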
def _step(self,
@ -725,12 +470,9 @@ class LearnedGradient(Optimizer):
# grad_cov_{dim}, which are for periodically updating the
# learning-rate matrices.
size = grad.shape[dim]
# accumulate some stats for learning the projections
proj_grad = state[f"proj_grad_{dim}"]
grad_cov = state[f"grad_cov_{dim}"]
this_m = p.transpose(-1, dim).reshape(-1, size) # parameter matrix M
this_g = grad.transpose(-1, dim).reshape(-1, size) # M_grad
proj_grad.add_(torch.matmul(this_m.t(), this_g))
# could perhaps accumulate grad_cov less frequently; it's only
# needed when we rediagonalize which is not that common.
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
@ -823,7 +565,8 @@ class LearnedGradient(Optimizer):
# direction we want to change canonical to diagonalized index, so have
# to transpose.
Q = Q.t()
# TODO: could possibly somehow force the output format to be unchanged.
# TODO: could possibly somehow force the output memory format to be
# unchanged.
x = x.transpose(-1, dim)
x = torch.matmul(x, Q)
x = x.transpose(-1, dim)
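# The transpose/matmul/transpose pattern above applies Q along `dim`; a small
# equivalence check with assumed shapes (here dim == 1):
_x = torch.randn(4, 5, 6)
_Qm = torch.randn(5, 5)
_y = torch.matmul(_x.transpose(-1, 1), _Qm).transpose(-1, 1)
_y2 = torch.einsum('abc,bd->adc', _x, _Qm)
assert torch.allclose(_y, _y2, atol=1e-4)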
@ -1059,29 +802,33 @@ class LearnedGradient(Optimizer):
def _smooth_param_diag_var(self,
param_diag: Tensor,
param_pow: float,
param_rel_eps: float,
param_rel_max: float,
param_reverse_cutoff: float) -> Tensor:
def _smooth_param_rms(self,
group: dict,
rms: Tensor,
rank: int) -> Tensor:
"""
Applies the smoothing formula to the eigenvalues of the parameter covariance
tensor.
(Actually when we use this function in _get_param_cov, they won't exactly
be the eigenvalues because we diagonalize relative to the grad_cov,
but they will be parameter covariances in certain directions).
"""
# normalize so mean = 1.0
param_diag = param_diag / (param_diag.mean() + 1.0e-20)
# use 1/(1/x + 1/y) to apply soft-max of param_diag with param_rel_max.
# use param_rel_eps here to prevent division by zero.
ans = 1. / (1. / (param_diag + param_rel_eps) + 1. / param_rel_max)
# 2nd factor below becomes a 1/param_diag type of function when
# param_diag >> param_reverse_cutoff.
ans = ans * (param_reverse_cutoff / (param_diag + param_reverse_cutoff))
return ans ** param_pow
Smooths and normalizes to mean=1 a tensor of diagonalized parameter rms values for one dimension
of a tensor. This will be used to construct a learning-rate matrix.
Args:
group: dict for configuration values
rms: a tensor with one dimension, representing diagonalized parameter
root-mean-square values.
rank: the assumed rank of the covariance matrix from which rms was derived,
used to decide how much smoothing to apply. This is actually the
minimal rank, if the parameter matrix were to stay fixed during training,
but it is still relevant to how robust the parameter covariance estimate is.
"""
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
eps = group["eps"]
size, = rms.shape
smooth = (smooth0 +
(smooth1 - smooth0) * size / (size + rank))
mean = rms.mean()
rms += eps + smooth * mean
new_mean = (eps + (smooth + 1) * mean)
return rms / new_mean
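# A quick sketch of the smoothing above with assumed values: the output mean
# is exactly 1 (up to the eps term), and extreme values are pulled toward 1.
_rms = torch.tensor([0.1, 0.5, 1.0, 3.0])
_eps, _smooth = 1.0e-08, 0.5     # _smooth as interpolated from smooth0/smooth1
_mean = _rms.mean()
_out = (_rms + _eps + _smooth * _mean) / (_eps + (_smooth + 1) * _mean)
assert torch.allclose(_out.mean(), torch.tensor(1.0))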
def _estimate_proj(self,
@ -1241,7 +988,7 @@ class LearnedGradient(Optimizer):
param_freqs = param_freqs.tolist()
param_periods = param_periods.tolist()
logging.info(f"LearnedGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
logging.info(f"PrAdam._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
print_info = random.random() < 0.01
i = 0
for p in group["params"]:
@ -1256,6 +1003,27 @@ class LearnedGradient(Optimizer):
# Do nothing more, as yet.
def _svd(x: Tensor):
# Wrapper for torch svd that catches errors and re-tries.
randU = None
for i in range(4):
try:
U, S, V = x.svd()
s = U.sum()
if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
raise RuntimeError(f"svd failed (or random test), sum={s}")
if randU is not None:
U = torch.matmul(randU.t(), U)
return U, S, V # success
except:
logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
f"error in SVD. Will try again after random change.")
U, S, V = torch.randn_like(x).svd()
x = torch.matmul(U, x)
if randU is None:
randU = U
else:
randU = torch.matmul(U, randU)
@ -1902,8 +1670,8 @@ def _test_eve_cain():
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
elif iter == 2: optim = LearnedGradient(m.parameters(), lr=0.03)
elif iter == 3: optim = LearnedGradient(m.parameters(), lr=0.03)
elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03)
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
start = timeit.default_timer()
@ -1965,11 +1733,18 @@ def _test_eve_cain():
logging.info("Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
def _test_svd():
device = 'cuda'
for i in range(1000):
X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(100, 100, device=device)
U, S, V = _svd(X)
X2 = torch.matmul(U*S, V.t())
assert torch.allclose(X2, X, atol=0.001)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
#_test_svd()
_test_eve_cain()
#_test_eden()