mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Implement new version of learning method. Does more complete diagonalization of grads than the previous methods.

parent a9edecd32c
commit 6810849058
@@ -24,10 +24,10 @@ from torch.optim import Optimizer
 from icefall import diagnostics  # only for testing code
 import logging

-class LearnedGradient(Optimizer):
+class PrAdam(Optimizer):
     """
-    Implements 'Learned Gradient' update.  It's a special factorization of the
-    parameter matrix.
+    Implements 'Projected Adam', a variant of Adam with projections for each
+    dim of each tensor.


    Args:
@@ -35,63 +35,70 @@ class LearnedGradient(Optimizer):
         lr:  The learning rate.  We will typically use a learning rate schedule that starts
              at 0.03 and decreases over time, i.e. much higher than other common
              optimizers.
-        betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
-             scalar tensors.  Must satisfy 0 < beta1 <= beta2 < 1.
         size_lr_scale: A scaling factor on the learning rate, that we use to update the
             scale of each parameter tensor.  If each parameter were decomposed
             as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
             is the scaling factor on the learning rate of p_scale.
         meta_lr_scale: The learning rate for the learning-rate matrix; this is a scale on
             the regular learning rate.
+        betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad only for
+             scalar tensors.  Must satisfy 0 < beta1 <= beta2 < 1.
+        param_pow: Power on the parameter covariance matrix; 1.0 means learn proportional
+             to parameter rms (1.0 will be too much; should be between 0 and 1).
+        param_rms_smooth0: Limiting value of the smoothing proportion for the parameter matrix, as
+             the assumed rank of the param covariance [== product of sizes on the other
+             tensor dims] approaches 0.
+        param_rms_smooth1: Smoothing proportion for the parameter matrix, if the assumed rank of
+             the param covariance equals the dimension of the covariance matrix.
+             param_rms_smooth{0,1} determine the smoothing proportions for other
+             conditions.
+        param_cov_freshness: Constant that determines how "recent" the parameter covariance
+             matrix is.  1.0 corresponds to a flat average over time; 2.0
+             would mean weighting with a quadratic function, etc.
         eps:  An epsilon to prevent division by zero
         param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
             learning the scale on the parameters (we'll keep it >= this size)
         param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
             learning the scale on the parameters (we'll keep it <= this size)
         lr_mat_min: Minimum singular value of learning-rate matrices Q.
         lr_mat_max: Maximum singular value of learning-rate matrices Q.
-        lr_est_period: The periodicity, in steps, with which we update the learning-rate matrices.
-        diagonalize_period: How often we re-diagonalize the gradient covariance; this
-             gets multiplied by lr_est_period, i.e. we diagonalize less frequently
-             than we update the learning-rate matrices.
         size_update_period: The periodicity, in steps, with which we update the size (scale)
             of the parameter tensor.  This is provided to save a little time.
+        lr_update_period: The periodicity, in steps, with which we update the learning-rate matrices.
     """

     def __init__(
         self,
         params,
         lr=3e-02,
-        size_lr_scale=0.1,
         meta_lr_scale=0.0,
         betas=(0.9, 0.98),
+        size_lr_scale=0.1,
+        param_pow=0.75,
+        param_rms_smooth0=0.75,
+        param_rms_smooth1=0.25,
+        param_cov_freshness=1.0,
         eps=1.0e-08,
-        size_update_period=1,
         param_min_rms=1.0e-05,
         param_max_rms=2.0,
         lr_mat_min=0.01,
         lr_mat_max=10.0,
-        lr_est_period=2,
-        diagonalize_period=4,
+        size_update_period=4,
+        lr_update_period=20,
     ):


         defaults = dict(
             lr=lr,
             size_lr_scale=size_lr_scale,
             meta_lr_scale=meta_lr_scale,
+            param_pow=param_pow,
+            param_rms_smooth0=param_rms_smooth0,
+            param_rms_smooth1=param_rms_smooth1,
+            param_cov_freshness=param_cov_freshness,
             betas=betas,
             eps=eps,
-            size_update_period=size_update_period,
             param_min_rms=param_min_rms,
             param_max_rms=param_max_rms,
             lr_mat_min=lr_mat_min,
             lr_mat_max=lr_mat_max,
-            lr_est_period=lr_est_period,
-            diagonalize_period=diagonalize_period,
+            size_update_period=size_update_period,
+            lr_update_period=lr_update_period,
         )

-        super(LearnedGradient, self).__init__(params, defaults)
+        super(PrAdam, self).__init__(params, defaults)

     def __setstate__(self, state):
-        super(LearnedGradient, self).__setstate__(state)
+        super(PrAdam, self).__setstate__(state)


     @torch.no_grad()
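For orientation, a minimal usage sketch of the new signature (the toy model, batch size and step count here are hypothetical, not from this commit; it assumes PrAdam as defined above is importable):

    import torch

    # Hypothetical toy model; PrAdam as defined in this file.
    model = torch.nn.Linear(128, 256)
    optim = PrAdam(model.parameters(), lr=3e-02, betas=(0.9, 0.98),
                   param_pow=0.75, size_update_period=4, lr_update_period=20)
    for step in range(100):
        loss = (model(torch.randn(32, 128)) ** 2).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()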
@@ -111,14 +118,11 @@ class LearnedGradient(Optimizer):
             lr = group["lr"]
             size_lr = lr * group["size_lr_scale"]
             beta1, beta2 = group["betas"]
-            lr_mat_min = group["lr_mat_min"]
-            lr_mat_max = group["lr_mat_max"]
             eps = group["eps"]
             size_update_period = group["size_update_period"]
             param_min_rms = group["param_min_rms"]
             param_max_rms = group["param_max_rms"]
-            lr_est_period = group["lr_est_period"]
-            diagonalize_period = group["diagonalize_period"]
+            lr_update_period = group["lr_update_period"]

             for p in group["params"]:
                 if p.grad is None:
@@ -128,7 +132,7 @@ class LearnedGradient(Optimizer):
                 grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError(
-                        "LearnedGradient optimizer does not support sparse gradients"
+                        "PrAdam optimizer does not support sparse gradients"
                     )
                 state = self.state[p]

@@ -177,29 +181,14 @@ class LearnedGradient(Optimizer):
                         continue

-                    # Q_{dim}, which we will call q in this comment, will be the learning-rate matrix for this
-                    # dimension, times an orthogonal matrix that we'll periodically
-                    # re-estimate when we "rediagonalize" (this diagonalizes
-                    # the gradient covariance).
-                    # If our parameter matrix M is of shape (..., size),
-                    # we'll view M as being M == torch.matmul(N, Q),
-                    # so q can be interpreted as having shape
-                    # (size, size), indexed [diagonalized_coordinate, canonical_coordinate].
+                    # By "diagonalize" we mean that we diagonalize the gradient covariance.
+                    # Q_{dim} is the learning-rate matrix for this
+                    # dimension, a matrix indexed [diagonalized_coordinate, canonical_coordinate].
+                    # It is used twice in the update step, once transposed and once without transpose.
                     state[f"Q_{dim}"] = torch.eye(size, **kwargs)

-                    # proj_grad_{dim} is the gradient w.r.t. `proj`,
-                    # which is a matrix we introduce to the right of Q, i.e.
-                    # instead of M == torch.matmul(N, Q) we view it as
-                    # M == torch.matmul(N, torch.matmul(Q, proj)),
-                    # or equivalently M == torch.matmul(M_underlying, proj),
-                    # where M_underlying is numerically equal to M, and proj is numerically
-                    # equal to I, but they are viewed as separate variables.
-                    #
-                    # proj_grad_{dim} will be set to the sum of
-                    # torch.matmul(M^T, M_grad) over a number of steps (`lr_est_period` steps),
-                    # and will be zeroed after we call self._update_lrs().
-                    state[f"proj_grad_{dim}"] = torch.zeros(size, size, **kwargs)
+                    # param_cov_{dim} is the averaged-over-time covariance of the parameters
+                    # on this dimension, treating all other dims as a batch axis.
+                    state[f"param_cov_{dim}"] = torch.zeros(size, size, **kwargs)

                     # grad_cov_{dim} is the covariance of gradients on this axis (without
                     # any co-ordinate changes), treating all other axes as a batch axis.
@@ -213,14 +202,6 @@ class LearnedGradient(Optimizer):
                     # instead of just using a temporary and smoothing the scalar factor.
                     state[f"grad_cov_{dim}"] = torch.zeros(size, size, **kwargs)

-                    # proj_grad_var_{dim} is the variance of gradients of the thing we're
-                    # going to multiply proj_{dim} by (between it and the "underlying
-                    # parameter matrix"), aggregated over each of its 2 dimensions
-                    # and decayed over time with beta1 each time we do a step on proj_{dim},
-                    # e.g. every lr_est_period steps.  It's a factored representation
-                    # of something like exp_avg_sq, used for estimating proj_{dim}; we avoid
-                    # storing an exp_avg_sq for proj_{dim}, to save GPU memory.
-                    state[f"proj_grad_var_{dim}"] = torch.ones(2, size, **kwargs)

                 step = state["step"]
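To make the per-dim state concrete, a sketch of what the initialization above allocates for a hypothetical weight of shape (256, 512) (dtype/device kwargs omitted; in the new code there is no proj_grad or proj_grad_var state):

    import torch

    size0, size1 = 256, 512   # p.shape
    state = {
        "Q_0": torch.eye(size0),                 # learning-rate matrix, starts as identity
        "Q_1": torch.eye(size1),
        "param_cov_0": torch.zeros(size0, size0),
        "param_cov_1": torch.zeros(size1, size1),
        "grad_cov_0": torch.zeros(size0, size0),
        "grad_cov_1": torch.zeros(size1, size1),
    }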
@@ -250,10 +231,9 @@ class LearnedGradient(Optimizer):
                     # parameter rms.  Updates delta.
                     self._step_scalar(beta1, beta2, eps, lr, p, grad, state)
                 else:
-                    if step % lr_est_period == 0 and step > 0:
+                    if step % lr_update_period == 0 and step > 0:
+                        self._accum_param_covs(group, p, state)
                         self._update_lrs(group, p, state)
-                    if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
-                        self._diagonalize_lrs(group, p, state)
                         self._zero_exp_avg_sq(state)
                     self._step(group, p, grad, state)
                 p.add_(delta)
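A small sketch of the new update cadence (assuming lr_update_period=20): covariance accumulation, the learning-rate update and the exp_avg_sq reset now all fire on the same steps, where the old code used lr_est_period plus a slower diagonalize_period.

    lr_update_period = 20
    update_steps = [s for s in range(1, 101)
                    if s % lr_update_period == 0 and s > 0]
    assert update_steps == [20, 40, 60, 80, 100]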
@@ -326,31 +306,19 @@ class LearnedGradient(Optimizer):
             # the factor of (1-beta1) relates to momentum.
             delta.add_(p, alpha=(1-beta1) * scale_step)


-    def _update_lrs(self,
-                    group: dict,
-                    p: Tensor,
-                    state: dict) -> None:
+    def _accum_param_covs(self,
+                          group: dict,
+                          p: Tensor,
+                          state: dict) -> None:
         """
-        Updates the learning-rate matrices state["Q_{dim}"].
-
-        Args:
-           group: dict to look up configuration values
-           p: parameter matrix that we are updating.  The learning rate matrices
-              are actually factors of p, so p itself will change when we change
-              them.
-           state: state dict for the current parameter
+        Updates state[f"param_cov_{dim}"] for each dim of tensor p.
         """
-        # meta_lr is the learning rate for the learning rate matrix.  Including the factor
-        # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
-        # to the lr_est_period (close to convergence).
-        meta_lr = group["lr"] * group["meta_lr_scale"] * (group["lr_est_period"] ** 0.5)
-        beta1 = group["betas"][0]
         eps = group["eps"]
-        delta = state["delta"]
-        ndim = p.ndim
-
-        for dim in range(ndim):
+        lr_update_period = group["lr_update_period"]
+        step = state["step"]
+        this_weight = (group["param_cov_freshness"] * lr_update_period /
+                       (step + lr_update_period))
+        for dim in range(p.ndim):
             size = p.shape[dim]
             if size == 1:
                 continue
@@ -359,153 +327,14 @@ class LearnedGradient(Optimizer):
             # all other dimensions of the tensor.
             M_full = p.transpose(dim, -1)
             M = M_full.reshape(-1, size)
+            this_param_cov = torch.matmul(M.t(), M)
+            # normalize the scale of this param_cov, in case the parameter scale
+            # changes significantly during training, which would cause some
+            # parts of the training timeline to be more highly weighted.
+            this_param_cov /= (this_param_cov.diag().mean() + eps)
+            state[f"param_cov_{dim}"].mul_(1-this_weight).add_(this_param_cov,
+                                                               alpha=this_weight)

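A standalone sketch of the accumulation above (the helper name is hypothetical; freshness and eps stand for the group-dict values):

    import torch

    def accum_param_cov(p, param_cov, dim, step, lr_update_period,
                        freshness=1.0, eps=1e-8):
        # Treat every dim except `dim` as a batch axis, form M^T M, normalize
        # its scale, and fold it into the running covariance with a weight
        # that decays roughly like freshness / (number of updates so far).
        size = p.shape[dim]
        M = p.transpose(dim, -1).reshape(-1, size)
        cov = M.t() @ M
        cov = cov / (cov.diag().mean() + eps)
        w = freshness * lr_update_period / (step + lr_update_period)
        return param_cov * (1 - w) + cov * w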
-            # proj_grad is a summed gradient, of shape (size, size),
-            # indexed [old_param_idx][new_param_idx] (the point is,
-            # old_param_idx corresponds to the index of the current M, new_param_idx
-            # corresponds to the index of the M after a small change).
-            #
-            # Suppose M is our parameter matrix p reshaped to (-1, size) with
-            # the -1 being all other dims of the tensor treated as a batch dim
-            # and "size" is the size of whatever dimension we are working on
-            # right now.  And suppose M_grad is the loss derivative w.r.t. M
-            # (for compatibility with things like Torch, we use a convention
-            # where all loss derivatives/grads have the same shape as the things they
-            # are derivatives with respect to).
-            #
-            # proj_grad is the accumulated value, over the last few steps, of M^T M_grad,
-            # which will be of shape (size, size).  For purposes of our update, if we view
-            # the parameter matrix M as (M_underlying proj) with
-            # proj being numerically equal to I but treated as a variable, i.e. as
-            # M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
-            # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj),
-            # where proj_grad == (M_grad^T M_underlying)^T == M_underlying^T M_grad == M^T M_grad.
-            #
-            # Now, proj_grad is convenient to compute, being independent of Q,
-            # but it's not in a very normalized
-            # space; we want to normalize it before doing gradient descent on it.
-            # We're going to view M as
-            #   M == N Q,
-            # where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size), indexed
-            # [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
-            # It's likely better numerically to do gradient descent on a `proj2` matrix
-            # that's between N and Q,
-            # i.e.:
-            #   M == N proj2 Q
-            # We want to obtain the derivative w.r.t. proj2.
-            # So the linearized pseudo-loss is
-            #   tr(M_grad^T M) == tr(M_grad^T N proj2 Q) == tr(Q M_grad^T N proj2),
-            # which we can view as tr(proj2_grad^T proj2), where
-            #   proj2_grad == (Q M_grad^T N)^T == N^T M_grad Q^T == Q^{-T} M^T M_grad Q^T
-            #   proj2_grad = Q^{-T} proj_grad Q^T       (eq.1)
-            #
-            # After computing proj2_grad we compute/update a factorized representation of
-            # its variance, factorized over its 2 axes, and do an Adam-type update
-            # for it, except without momentum at this level (momentum will be applied
-            # later).  This gives us proj2_delta, according to the
-            # equation:
-            #   proj2 = I + proj2_delta
-            # (again: proj2_delta is derived from proj2_grad using an Adam-style update).
-            #
-            # We now have to update the "learning-rate matrix" Q using proj2, as in
-            #   Q := proj2 Q
-            #   Q := (proj2_delta + I) Q, or:
-            #   Q += proj2_delta Q      (eq.2)
-            #
-            # We also need to update M itself; this is easiest done by computing proj_delta
-            # (which is the difference of proj from I).  The relationship between proj and proj2
-            # is given by equating
-            #   M == N proj2 Q == N Q proj,
-            # so proj2 Q == Q proj,
-            #    proj == Q^{-1} proj2 Q,
-            # and subtracting out the constant "I" part,
-            #   proj_delta == Q^{-1} proj2_delta Q      (eq.3)
-            #
-            # and then because we view the parameter matrix as
-            #  "M proj",
-            # with proj being numerically equal to I but learned here,
-            # it becomes M (I + proj_delta) = M + M proj_delta.
-            # So the update to M is:
-            #   M += M proj_delta,
-            # or equivalently:
-            #   M_delta = M proj_delta     (eq.5)
-            proj_grad = state[f"proj_grad_{dim}"]
-            Q = state[f"Q_{dim}"]
-
-            # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
-            # as we'll reach equilibrium with M less rapidly.
-            delta_scale = 0.5
-
-            if True:
-                # Simpler update: do it to the right of Q.
-                # Symmetrize; otherwise there is a bad interaction with the regular update.
-                proj_grad[:] = proj_grad + proj_grad.t()
-                proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
-                denom = proj_grad_var.sqrt()
-                proj_delta = proj_grad / denom
-                Q_delta = torch.matmul(Q, proj_delta)
-                M_delta = torch.matmul(M, proj_delta)
-                M_delta_full = M_delta.reshape(M_full.shape)
-                this_delta = M_delta_full.transpose(dim, -1)
-                delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
-                # there is no momentum on Q.
-                Q.add_(Q_delta, alpha=-meta_lr)
-                proj_grad.zero_()
-                continue
-
-            # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T.
-            # torch.linalg.solve(A, B) returns A^{-1} B.
-            try:
-                # not all versions of Torch have this.
-                Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
-            except:
-                # torch.solve has its args the other way round and also returns LU.
-                Q_invt_proj_grad, _ = torch.solve(proj_grad, Q.t())
-
-            # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
-            proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
-
-            # symmetrize, otherwise there is a bad interaction with the regular
-            # update in self._step().
-            proj2_grad = proj2_grad + proj2_grad.t()
-
-            # for now just estimate proj2_grad_var from proj2_grad; later, we
-            # may try aggregating it across iterations.  This is like exp_avg_sq
-            # in Adam.
-            proj2_grad_var = self._get_matrix_var(proj2_grad, group["eps"])
-
-            denom = proj2_grad_var.sqrt()
-
-            # we'll take into account the factor of -meta_lr * (1-beta1) later on;
-            # actually proj2_delta is the negative of a parameter change.
-            proj2_delta = proj2_grad / denom
-
-            # See (eq.2), Q += proj2_delta Q
-            Q_delta = torch.matmul(proj2_delta, Q)
-
-            # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
-            try:
-                proj_delta = torch.linalg.solve(Q, Q_delta)
-            except:
-                # torch.solve has its args the other way round and also returns LU.
-                proj_delta, _ = torch.solve(Q_delta, Q)
-
-            M_delta = torch.matmul(M, proj_delta)  # (eq.5), M_delta = M proj_delta
-
-            M_delta_full = M_delta.reshape(M_full.shape)
-
-            this_delta = M_delta_full.transpose(dim, -1)
-            # we'll be applying momentum on delta, so multiply by (1-beta1) to
-            # get the total magnitude of the change right.  The 2 changes below
-            # are the final changes, the only 2 we make in this loop that have
-            # side effects.
-            delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
-            # there is no momentum on Q.
-            Q.add_(Q_delta, alpha=-meta_lr)
-
-            proj_grad.zero_()

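The (eq.1)/(eq.3) bookkeeping in the removed code can be checked numerically; a sketch with random matrices (a well-conditioned Q is assumed here so the solves are stable), verifying that the linearized pseudo-loss change agrees in the two coordinate systems:

    import torch

    size = 5
    Q = torch.randn(size, size) + 3 * torch.eye(size)   # well-conditioned
    M = torch.randn(10, size)
    M_grad = torch.randn(10, size)
    proj_grad = M.t() @ M_grad                          # M^T M_grad
    # (eq.1): proj2_grad = Q^{-T} proj_grad Q^T
    proj2_grad = torch.linalg.solve(Q.t(), proj_grad) @ Q.t()
    # (eq.3): proj_delta = Q^{-1} proj2_delta Q, for any proposed proj2_delta
    proj2_delta = torch.randn(size, size)
    proj_delta = torch.linalg.solve(Q, proj2_delta @ Q)
    # the pseudo-loss change is the same in either coordinate system:
    lhs = torch.trace(proj2_grad.t() @ proj2_delta)
    rhs = torch.trace(proj_grad.t() @ proj_delta)
    assert torch.allclose(lhs, rhs, atol=1e-2)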
     def _zero_exp_avg_sq(self,
                          state: dict) -> None:
@@ -516,25 +345,12 @@ class LearnedGradient(Optimizer):
         state["scalar_exp_avg_sq"].zero_()
         state["zero_step"] = state["step"]

-    def _diagonalize_lrs(self,
-                         group: dict,
-                         p: Tensor,
-                         state: dict) -> None:
+    def _update_lrs(self,
+                    group: dict,
+                    p: Tensor,
+                    state: dict) -> None:
         """
-        Diagonalizes the learning-rate matrices state["Q_{dim}"], and also
-        applies the min- and max- singular value constraints.  This Q
-        is applied to the parameter matrix M such that we interpret
-        M == N Q.
-        By "diagonalize", we mean to left-multiply Q by an orthogonal matrix,
-        Q := U Q,
-        so that the covariance of gradients w.r.t. N, measured
-        over the 2nd index of N while treating its 1st index as a batch index,
-        is diagonal.
-
-        Before doing this, it also limits the singular values of Q so that the
-        min,max singular values are lr_mat_min,lr_mat_max, then renormalizes so the
-        mean is 1.0 and limits again.
+        Computes learning-rate matrices Q for each dim of this tensor.
         Args:
           group: dict to look up configuration values
           p: parameter matrix that we are updating.  The learning rate matrices
@@ -542,158 +358,87 @@ class LearnedGradient(Optimizer):
              them.
           state: state dict for the current parameter
         """
-        # meta_lr is the learning rate for the learning rate matrix.  Including the factor
-        # (group["lr_est_period"] ** 0.5) is intended to make it approximately invariant
-        # to the lr_est_period (close to convergence).

         lr_mat_min = group["lr_mat_min"]
         lr_mat_max = group["lr_mat_max"]

         ndim = p.ndim
+        numel = p.numel()

         for dim in range(ndim):
             size = p.shape[dim]
             if size == 1:
                 continue
+            param_cov = state[f"param_cov_{dim}"]
+            U, S, V = _svd(param_cov)
+            rank = numel // size
+            rms = self._smooth_param_rms(group, S.sqrt(), rank)
+
+            if random.random() < 0.05:
+                logging.info(f"Shape={tuple(p.shape)}, dim={dim}, size={size}, rms={rms[::10]}")

             Q = state[f"Q_{dim}"]
-            Q_orig = Q.clone()  # TEMP
-            for i in range(4):
-                try:
-                    self._diagonalize_lr(group, p, dim, state)
-                    break  # break from the loop over i: nothing was raised, so it succeeded.
-                except:
-                    logging.warning(f"self._diagonalize_lr raised: i={i}, Q.sum()={Q.sum()}: likely "
-                                    f"error in SVD.  Will try again after random change.")
-                    # Likely torch SVD failed; its implementation for GPU has bugs in some torch
-                    # versions, e.g. 1.7.1.  Left-multiply Q by an arbitrary orthogonal matrix, and
-                    # we'll try again.  (Q is going to be diagonalized on that axis anyway, so this
-                    # shouldn't meaningfully affect the output.)
-                    U, S, V = torch.randn_like(Q_orig).svd()
-                    # U is a random orthogonal matrix.
-                    Q[:] = torch.matmul(U, Q_orig)
+            Q[:] = (U * rms).t()

+            if True:
+                # This block does the actual diagonalization.
+                # Suppose the actual parameter matrix p is M, of shape (-1, size), where
+                # the -1 represents all other tensor dims treated as a batch dimension.
+                # M_grad is the same shape as M.  We could write a pseudo-loss as
+                #   loss = tr(M_grad^T M)
+                # Because we can decompose M as M == N Q, we can write:
+                #   loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
+                # so we can write this as tr(N_grad^T N),
+                # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
+                # Now,
+                #   grad_cov == M_grad^T M_grad,
+                # decaying-averaged over minibatches; this is of shape (size, size).
+                # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
+                # (which is what we want to diagonalize), as:
+                #   N_grad_cov = N_grad^T N_grad
+                #              = Q M_grad^T M_grad Q^T
+                #              = Q grad_cov Q^T
+                # (note: this makes sense because the 1st index of Q is the diagonalized
+                # index).
+
+                def dispersion(v):
+                    # a ratio that will be larger if the eigenvalues are more dispersed.
+                    return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
+
+                grad_cov = state[f"grad_cov_{dim}"]
+                N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
+                N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric
+                U, S, V = _svd(N_grad_cov)
+                if random.random() < 0.1:
+                    logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
+                                 f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
+
+                # N_grad_cov is SPD, so
+                #   N_grad_cov = U S U^T.
+
+                # Now, we can diagonalize N_grad_cov with:
+                #   U^T N_grad_cov U == S.
+                # N_grad_cov is a sum of N_grad^T N_grad.
+                # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
+                # The linearized pseudo-loss can be written as tr(N_grad^T N).
+                # This can be written as tr(U U^T N_grad^T N), since U U^T == I,
+                # which we can rearrange as tr(U^T N_grad^T N U).  This can be interpreted
+                # as tr(hat_N_grad^T hat_N), where:
+                #   hat_N_grad = N_grad U
+                #   hat_N = N U
+                # (hat_N means \hat{N}, or N with a hat on it).
+                # So if we interpret hat_N = N U, the gradient covariance w.r.t.
+                # hat_N will be diagonalized.  We also modify Q to hat_Q when
+                # we modify hat_N, to keep the product M unchanged:
+                #   M = N Q = N U U^T Q = hat_N hat_Q
+                # This can be done by setting
+                #   hat_Q := U^T Q     (eq.10)
+                #
+                # This is the only thing we have to do, as N is implicit
+                # and not materialized at this point.
+                Q[:] = torch.matmul(U.t(), Q)

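That Q := U^T Q really diagonalizes the gradient covariance w.r.t. N can be checked directly; a sketch with a random SPD stand-in for grad_cov_{dim}:

    import torch

    size = 6
    Q = torch.randn(size, size)
    A = torch.randn(size, size)
    grad_cov = A @ A.t()                       # SPD stand-in for grad_cov_{dim}
    N_grad_cov = Q @ grad_cov @ Q.t()
    U, S, _ = torch.linalg.svd(N_grad_cov)     # SPD, so N_grad_cov = U S U^T
    Q_hat = U.t() @ Q                          # (eq.10)
    D = Q_hat @ grad_cov @ Q_hat.t()           # the new N_grad_cov
    off_diag = D - torch.diag(D.diag())
    assert off_diag.abs().max() < 1e-3 * D.diag().max()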
-    def _diagonalize_lr(self,
-                        group: dict,
-                        p: Tensor,
-                        dim: int,
-                        state: dict) -> None:
-        """
-        This piece of code was broken out of _diagonalize_lrs so that we can more easily
-        recover from errors in torch SVD, which can be quite common in some torch versions;
-        e.g. in torch 1.7.1+cu101, CUDA SVD generates NaNs for some very reasonably-sized
-        inputs.  I have a 200 by 200 "bad case", whose min and max elements are
-        -0.2766 and 0.3444 respectively, that generates NaNs in its U and V factors.
-        Having closely-spaced singular values seems to present a problem.
-        """
-        lr_mat_min = group["lr_mat_min"]
-        lr_mat_max = group["lr_mat_max"]
-
-        Q = state[f"Q_{dim}"]
-        if True:
-            # This block limits the singular values of Q.
-            U, S, V = Q.svd()
-            S_new = S.clamp(lr_mat_min, lr_mat_max)
-            S_new *= 1.0 / S_new.mean()  # normalize so mean is 1.0
-            S_new.clamp_(lr_mat_min, lr_mat_max)  # apply limits once more.
-            # Reconstruct Q with the modified S.
-            Q[:] = torch.matmul(U * S_new, V.t())
-            if random.random() < 0.03:
-                subsample = max(1, S.numel() // 20)
-                logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}")
-
-        if True:
-            # This block does the actual diagonalization.
-            # Suppose the actual parameter matrix p is M, of shape (-1, size), where
-            # the -1 represents all other tensor dims treated as a batch dimension.
-            # M_grad is the same shape as M.  We could write a pseudo-loss as
-            #   loss = tr(M_grad^T M)
-            # Because we can decompose M as M == N Q, we can write:
-            #   loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
-            # so we can write this as tr(N_grad^T N),
-            # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
-            # Now,
-            #   grad_cov == M_grad^T M_grad,
-            # decaying-averaged over minibatches; this is of shape (size, size).
-            # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
-            # (which is what we want to diagonalize), as:
-            #   N_grad_cov = N_grad^T N_grad
-            #              = Q M_grad^T M_grad Q^T
-            #              = Q grad_cov Q^T
-            # (note: this makes sense because the 1st index of Q is the diagonalized
-            # index).
-
-            def dispersion(v):
-                # a ratio that will be larger if the eigenvalues are more dispersed.
-                return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
-
-            grad_cov = state[f"grad_cov_{dim}"]
-            N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
-            N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric
-            U, S, V = N_grad_cov.svd()
-            if random.random() < 0.1:
-                logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
-                             f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
-
-            # N_grad_cov is SPD, so
-            #   N_grad_cov = U S U^T.
-
-            # Now, we can diagonalize N_grad_cov with:
-            #   U^T N_grad_cov U == S.
-            # N_grad_cov is a sum of N_grad^T N_grad.
-            # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
-            # The linearized pseudo-loss can be written as tr(N_grad^T N).
-            # This can be written as tr(U U^T N_grad^T N), since U U^T == I,
-            # which we can rearrange as tr(U^T N_grad^T N U).  This can be interpreted
-            # as tr(hat_N_grad^T hat_N), where:
-            #   hat_N_grad = N_grad U
-            #   hat_N = N U
-            # (hat_N means \hat{N}, or N with a hat on it).
-            # So if we interpret hat_N = N U, the gradient covariance w.r.t.
-            # hat_N will be diagonalized.  We also modify Q to hat_Q when
-            # we modify hat_N, to keep the product M unchanged:
-            #   M = N Q = N U U^T Q = hat_N hat_Q
-            # This can be done by setting
-            #   hat_Q := U^T Q     (eq.10)
-            #
-            # This is the only thing we have to do, as N is implicit
-            # and not materialized at this point.
-            Q[:] = torch.matmul(U.t(), Q)
-
-        s = Q[:].sum()
-        if not s == s:  # if a NaN was generated:
-            raise RuntimeError('Likely SVD failure')

-    def _get_matrix_var(self,
-                        x: Tensor,
-                        eps: float) -> Tensor:
-        """
-        Estimate a variance of elements of x, which must be a tensor with exactly
-        2 dimensions, neither of which can have size == 1.  This is done using
-        a structured representation, so the variance is a product over the 2
-        dimensions (otherwise, it wouldn't be possible to estimate a variance from
-        a single sample).  Eventually we may do this via aggregation across
-        time as well; this approach seems to give a high-variance estimate even
-        for largish dimensions, e.g. if x.shape == (100,100) the 95th percentile
-        is from about 0.6 to 1.6.
-
-        Args:
-           x: tensor to estimate the variance of elements of.  Requires x.ndim == 2;
-              neither size may be 1.
-           eps: Epsilon to avoid zero variance; epsilon is dimensionally an rms value.
-        """
-        assert x.ndim == 2 and 1 not in list(x.shape)  # TODO: remove?
-
-        var0 = (x**2).mean(dim=1, keepdim=True) + eps*eps
-        var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps*eps
-        # estimate var0 one more time.
-        var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps*eps
-        return var1 * var0

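The factored variance in the removed _get_matrix_var can be sketched standalone: one variance factor per row times one per column, refined alternately; for unit-variance input the product hovers around 1.

    import torch

    x = torch.randn(100, 100)
    eps = 1e-8
    var0 = (x ** 2).mean(dim=1, keepdim=True) + eps * eps                  # (100, 1)
    var1 = ((x / var0.sqrt()) ** 2).mean(dim=0, keepdim=True) + eps * eps  # (1, 100)
    var0 = ((x / var1.sqrt()) ** 2).mean(dim=1, keepdim=True) + eps * eps  # re-estimate
    var = var0 * var1           # broadcasts to (100, 100)
    print(var.mean().item())    # ~1.0 for unit-variance input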
     def _step(self,
@@ -725,12 +470,9 @@ class LearnedGradient(Optimizer):
                 # grad_cov_{dim}, which are for periodically updating the
                 # learning-rate matrices.
                 size = grad.shape[dim]
-                # accumulate some stats for learning the projections
-                proj_grad = state[f"proj_grad_{dim}"]
                 grad_cov = state[f"grad_cov_{dim}"]
                 this_m = p.transpose(-1, dim).reshape(-1, size)  # parameter matrix M
                 this_g = grad.transpose(-1, dim).reshape(-1, size)  # M_grad
-                proj_grad.add_(torch.matmul(this_m.t(), this_g))
                 # could perhaps accumulate grad_cov less frequently; it's only
                 # needed when we rediagonalize, which is not that common.
                 grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
@ -823,7 +565,8 @@ class LearnedGradient(Optimizer):
|
||||
# direction we want to change canonical to diagonalized index, so have
|
||||
# to transpose.
|
||||
Q = Q.t()
|
||||
# TODO: could possibly somehow force the output format to be unchanged.
|
||||
# TODO: could possibly somehow force the output memory format to be
|
||||
# unchanged.
|
||||
x = x.transpose(-1, dim)
|
||||
x = torch.matmul(x, Q)
|
||||
x = x.transpose(-1, dim)
|
||||
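The transpose/matmul/transpose pattern above applies a (size, size) projection along an arbitrary dim while preserving the tensor's shape; a minimal sketch:

    import torch

    x = torch.randn(4, 5, 6)
    dim = 1
    Q = torch.randn(x.shape[dim], x.shape[dim])
    y = torch.matmul(x.transpose(-1, dim), Q).transpose(-1, dim)
    assert y.shape == x.shape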
@@ -1059,29 +802,33 @@ class LearnedGradient(Optimizer):



-    def _smooth_param_diag_var(self,
-                               param_diag: Tensor,
-                               param_pow: float,
-                               param_rel_eps: float,
-                               param_rel_max: float,
-                               param_reverse_cutoff: float) -> Tensor:
+    def _smooth_param_rms(self,
+                          group: dict,
+                          rms: Tensor,
+                          rank: int) -> Tensor:
         """
-        Applies the smoothing formula to the eigenvalues of the parameter covariance
-        tensor.
-        (Actually when we use this function in _get_param_cov, they won't exactly
-        be the eigenvalues because we diagonalize relative to the grad_cov,
-        but they will be parameter covariances in certain directions.)
-        """
-        # normalize so mean = 1.0
-        param_diag = param_diag / (param_diag.mean() + 1.0e-20)
-        # use 1/(1/x + 1/y) to apply a soft-max of param_diag with param_rel_max;
-        # use param_rel_eps here to prevent division by zero.
-        ans = 1. / (1. / (param_diag + param_rel_eps) + 1. / param_rel_max)
-
-        # 2nd factor below becomes a 1/param_diag type of function when
-        # param_diag >> param_reverse_cutoff.
-        ans = ans * (param_reverse_cutoff / (param_diag + param_reverse_cutoff))
-        return ans ** param_pow
+        Smooths, and normalizes to mean=1, a tensor of diagonalized parameter rms values for
+        one dimension of a tensor.  This will be used to construct a learning-rate matrix.
+
+        Args:
+          group: dict for configuration values
+          rms: a tensor with one dimension, representing diagonalized parameter
+              root-mean-square values.
+          rank: the assumed rank of the covariance matrix from which rms was derived,
+              used to decide how much smoothing to apply.  This is actually the
+              minimal rank, if the parameter matrix were to stay fixed during training,
+              but it is still relevant to how robust the parameter covariance estimate is.
+        """
+        smooth0 = group["param_rms_smooth0"]
+        smooth1 = group["param_rms_smooth1"]
+        eps = group["eps"]
+        size, = rms.shape
+        smooth = (smooth0 +
+                  (smooth1 - smooth0) * size / (size + rank))
+        mean = rms.mean()
+        rms += eps + smooth * mean
+        new_mean = (eps + (smooth + 1) * mean)
+        return rms / new_mean

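A worked example (sketch) of the smoothing above, for dim 0 of a hypothetical (256, 512) weight, so size=256 and rank=512:

    import torch

    smooth0, smooth1, eps = 0.75, 0.25, 1e-8
    size, rank = 256, 512
    smooth = smooth0 + (smooth1 - smooth0) * size / (size + rank)   # ~0.583
    rms = torch.rand(size) + 0.5
    out = (rms + eps + smooth * rms.mean()) / (eps + (smooth + 1) * rms.mean())
    print(round(smooth, 3), out.mean().item())   # out has mean ~1.0 by construction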
     def _estimate_proj(self,
@@ -1241,7 +988,7 @@ class LearnedGradient(Optimizer):
         param_freqs = param_freqs.tolist()
         param_periods = param_periods.tolist()

-        logging.info(f"LearnedGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
+        logging.info(f"PrAdam._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
         print_info = random.random() < 0.01
         i = 0
         for p in group["params"]:
@@ -1256,6 +1003,27 @@ class LearnedGradient(Optimizer):
         # Do nothing more, as yet.


+def _svd(x: Tensor):
+    # Wrapper for torch svd that catches errors and re-tries.
+    randU = None
+    for i in range(4):
+        try:
+            U, S, V = x.svd()
+            s = U.sum()
+            if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
+                raise RuntimeError(f"svd failed (or random test), sum={s}")
+            if randU is not None:
+                U = torch.matmul(randU.t(), U)
+            return U, S, V  # success
+        except:
+            logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
+                            f"error in SVD.  Will try again after random change.")
+            U, S, V = torch.randn_like(x).svd()
+            x = torch.matmul(U, x)
+            if randU is None:
+                randU = U
+            else:
+                randU = torch.matmul(U, randU)

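Why the randU.t() factor in _svd recovers the original factorization (a sketch, not code from this commit): if x is replaced by R x for a random orthogonal R, then from R x = U' S V^T we get x = (R^T U') S V^T.

    import torch

    x = torch.randn(8, 8)
    R, _, _ = torch.randn(8, 8).svd()      # R is a random orthogonal matrix
    U2, S, V = (R @ x).svd()
    x_rec = (R.t() @ U2 * S) @ V.t()       # == (R^T U') diag(S) V^T
    assert torch.allclose(x_rec, x, atol=1e-4)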
@@ -1902,8 +1670,8 @@ def _test_eve_cain():

     if iter == 0: optim = Eve(m.parameters(), lr=0.003)
     elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
-    elif iter == 2: optim = LearnedGradient(m.parameters(), lr=0.03)
-    elif iter == 3: optim = LearnedGradient(m.parameters(), lr=0.03)
+    elif iter == 2: optim = PrAdam(m.parameters(), lr=0.03)
+    elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03)
     scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

     start = timeit.default_timer()
@@ -1965,11 +1733,18 @@ def _test_eve_cain():

     logging.info(f"Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")


+def _test_svd():
+    device = 'cuda'
+    for i in range(1000):
+        X = torch.eye(100, device=device) + 1.0e-08 * torch.randn(100, 100, device=device)
+        U, S, V = _svd(X)
+        X2 = torch.matmul(U*S, V.t())
+        assert torch.allclose(X2, X, atol=0.001)
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    #_test_svd()
     _test_eve_cain()
     #_test_eden()