Big simplification to update rule
This commit is contained in:
parent a80a8abf0c
commit 105d49d31b
@@ -427,16 +427,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        else:
            if self._is_lr_update_step(group, state):
                self._update_param_cov(group, p, state)

                P_proj = self._compute_bases(group, p.shape, state)
                # Only update the parameter-dependent part of the learning
                # rate matrices at most every other time we reach here, and
                # less frequently than that later in training.
                #self._update_param_scales(group, p, state, P_proj)
                #self._update_param_scales_simple(group, p, state, P_proj)

                # We won't be doing this any more.
                #self._diagonalize_grad_cov(group, p, state)
                self._compute_bases(group, p.shape, state)
                self._zero_exp_avg_sq(state)
        if step % grad_cov_period == 0:
            self._update_grad_cov(group, p, state)
@@ -588,247 +579,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        return (step >= zero_step + cur_update_period)

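# ------------------------------------------------------------------------------
# Standalone sketch (not taken from this file) of the scheduling idea visible in the
# hunks above: learning-rate-matrix updates fire once step >= zero_step +
# cur_update_period, and the period is grown after each update so that updates become
# less frequent later in training.  The names `zero_step`, `initial_period` and
# `period_growth` are hypothetical.
class _ToyLrUpdateSchedule:
    def __init__(self, zero_step: int = 0, initial_period: int = 4,
                 period_growth: float = 1.5):
        self.zero_step = zero_step
        self.cur_update_period = initial_period
        self.period_growth = period_growth

    def is_lr_update_step(self, step: int) -> bool:
        if step >= self.zero_step + self.cur_update_period:
            # Record that we updated, and grow the period for next time.
            self.zero_step = step
            self.cur_update_period = int(self.cur_update_period * self.period_growth)
            return True
        return False

# With the defaults above, updates would fire at steps 4, 10, 19, 32, ...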
    def _update_param_scales_simple(self,
                                    group: dict,
                                    p: Tensor,
                                    state: dict,
                                    P_proj: List[Optional[Tensor]]) -> None:
        batch_size = p.shape[0]
        numel = p.numel() // batch_size  # defined here so the assertion below has it in scope
        for dim in range(1, p.ndim):
            size = p.shape[dim]
            try:
                Q = state[f"Q_{dim}"]
            except KeyError:
                assert size == 1 or size == numel, size
                continue  # e.g. size == 1 or size == numel

            (batch_size, num_blocks, block_size, block_size) = Q.shape
            this_P_proj = P_proj[dim].reshape(batch_size, num_blocks, block_size, 1)

            # The following normalization step will ensure the Frobenius
            # norm is unchanged from applying this scale: at least,
            # assuming "grad / denom" gives uncorrelated outputs, so that
            # they will have equal variances after projecting to the space
            # where the parameter var is diagonalized... this is *roughly*
            # true because the gradients at the point where we compute "grad
            # / denom" should be decorrelated, at least considering
            # individual tensor dims.
            this_P_proj /= _mean(this_P_proj, exclude_dims=[0], keepdim=True)

            if True:
                # debug info.
                scale = this_P_proj.sqrt()
                step = state["step"]
                scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
                logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")

            Q *= this_P_proj.sqrt()
            logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}")

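# ------------------------------------------------------------------------------
# Standalone sketch (plain torch, small hypothetical tensors) of the normalization used
# above: if the per-row scales s satisfy mean(s**2) == 1, then scaling the rows of an
# orthogonal Q by s leaves the expected Frobenius norm of Q @ g unchanged for whitened
# (uncorrelated, unit-variance) g.
import torch

def _normalize_scales(P_proj_diag: torch.Tensor) -> torch.Tensor:
    # P_proj_diag: positive per-coordinate variances, shape (batch_size, size).
    # Divide by the per-batch-element mean so that mean(scale**2) == 1.
    P_norm = P_proj_diag / P_proj_diag.mean(dim=1, keepdim=True)
    return P_norm.sqrt()

torch.manual_seed(0)
batch_size, size = 1, 64
P = torch.rand(batch_size, size) + 0.1           # fake projected parameter variances
scale = _normalize_scales(P)                      # shape (batch_size, size)
Q = torch.linalg.qr(torch.randn(size, size))[0]   # an orthogonal matrix
g = torch.randn(size, 10000)                      # whitened "grad / denom" samples
before = (Q @ g).pow(2).mean()
after = ((scale[0].unsqueeze(-1) * Q) @ g).pow(2).mean()
print(before, after)    # approximately equal, since mean(scale**2) == 1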
    def _update_param_scales(self,
                             group: dict,
                             p: Tensor,
                             state: dict,
                             P_proj: List[Optional[Tensor]]) -> None:
        """
        Modifies the scales on the rows of the learning-rate matrices Q for each dim of
        this tensor, to take into account the estimated parameter covariance.

        Args:
           group: dict to look up configuration values
               p: parameter matrix that we are updating.  The learning-rate matrices
                  are actually factors of p, so p itself will change when we change
                  them.
           state: state dict for the current parameter
          P_proj: a list indexed by dim, containing tensors of shape (batch_size, size),
                  where size == p.shape[dim], which represent the parameter covariance
                  projected by state[f"Q_{dim}"] (i.e. by the value of this variable
                  at entry to this function)
        """
        ndim = p.ndim
        batch_size = p.shape[0]
        numel = p.numel() // batch_size

        if numel in p[0].shape:
            return  # Nothing to do for this parameter matrix, e.g. a bias or a scalar.

        scale_arr = [None] * ndim

        # The small random part is to ensure there are no exact zeros, e.g.
        # if we initialized some parameters to zero.
        #
        # We are going to make "cur_p" a corrected version of the parameter
        # covariance that is projected with the orthogonal projections U, and that,
        # on the axes given by these projections, has covariance proportional
        # to that of state[f"param_cov_{dim}"] on those same axes.  So it's
        # *slightly* whitened, just to match the stats, but not completely
        # whitened.
        eps = 1.0e-20
        cur_p = p + eps * torch.randn_like(p)

        for dim in range(1, ndim):
            size = p.shape[dim]
            try:
                Q = state[f"Q_{dim}"]
            except KeyError:
                assert size == 1 or size == numel, size
                continue  # e.g. size == 1 or size == numel

            (batch_size, num_blocks, block_size, block_size) = Q.shape

            M = cur_p.transpose(dim, -1)
            # If p were of shape (batch_size, x, size, y, z),
            # after the next line M would be of shape
            # (batch_size, x, z, y, num_blocks, block_size).
            M = M.reshape(*M.shape[:-1], num_blocks, block_size)
            M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, z, y, block_size)

            while Q.ndim < M.ndim:
                Q = Q.unsqueeze(2)
            # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size),
            # with indexes [batch_index, block_index, 1, 1, diagonalized_coordinate, canonical_coordinate],
            # so we need to transpose Q as we convert M to the diagonalized co-ordinate.
            M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, z, y, block_size)
            M = _move_dim(M, 1, -2)   # (batch_size, x, z, y, num_blocks, block_size)
            M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, z, y, size)
            cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)

            # cur_param_var is a diagonal parameter variance over dimension `dim`
            # of the current "slightly-whitened" parameter; it
            # will have a shape like (batch_size, 1, size, 1, 1) if dim == 2.
            cur_param_var = _mean(cur_p**2,
                                  exclude_dims=[0, dim],
                                  keepdim=True)  # (batch_size, 1, size, 1, 1) if dim==2
            smoothed_param_var = P_proj[dim]  # (batch_size, size)
            S = smoothed_param_var.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2

            # cur_param_var would have the same values as S if the variance stats
            # P_proj[dim] had been accumulated from this exact parameter matrix and
            # not smoothed, but actually they contain older versions of the
            # parameter covariance and they have been smoothed, so they will, in
            # general, be less extreme ("whiter spectrum").  We scale p so that
            # it matches the estimated variance P_proj[dim]; the idea is to ensure it doesn't
            # have too-extreme eigenvalues (where the stats permit).
            # Actually we could just use P_proj[dim].sqrt(), suitably scaled,
            # as the scales on the rows of Q (see _update_param_scales_simple(), which does
            # exactly this), but there is a problem of "counting things twice",
            # which is easiest to understand for a 2-dimensional tensor, where the
            # singular values show up identically in the covariance over either axis.
            # The estimation procedure in this function avoids the "counting things twice"
            # problem, at the expense of quite a bit of extra complexity.
            scale = (S / cur_param_var.clamp(min=eps)).sqrt()

            if True:
                S_tmp = S.reshape(batch_size, size)
                cur_tmp = cur_param_var.reshape(batch_size, size)
                scale_tmp = scale.reshape(batch_size, size)
                skip = 10 if size > 40 else 1
                logging.info(f"dim={dim}/{ndim}, cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")

            if random.random() < 0.01:
                skip = 10 if size < 20 else 1
                logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")

            # scale shape: (batch_size, 1, size, 1, 1)
            cur_p *= scale

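# ------------------------------------------------------------------------------
# The helper `_mean(x, exclude_dims=..., keepdim=True)` used throughout this file is not
# shown in this diff; the sketch below implements the semantics the surrounding comments
# appear to assume (a mean over every dim *except* the excluded ones), using plain torch.
# Treat it as an assumption, not the repository's actual definition.
import torch
from typing import List

def _mean_sketch(x: torch.Tensor, exclude_dims: List[int], keepdim: bool = False) -> torch.Tensor:
    exclude = [d % x.ndim for d in exclude_dims]
    dims = tuple(d for d in range(x.ndim) if d not in exclude)
    return x.mean(dim=dims, keepdim=keepdim)

p = torch.randn(2, 3, 5, 7)   # (batch_size, x, size, y), with dim == 2
var = _mean_sketch(p**2, exclude_dims=[0, 2], keepdim=True)
print(var.shape)              # torch.Size([2, 1, 5, 1])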
        # OK, at this point we have a matrix cur_p that is (somewhat)
        # diagonalized by orthogonal matrices in each non-trivial dim, and that is
        # also multiplied by scalars to match the accumulated covariance stats,
        # i.e. "slightly whitened" (in general this will make a modest
        # difference, making the eigenvalue distribution a bit flatter).  Now we
        # will work out the "scaling" part of the learning-rate matrices Q.  We
        # can't do this independently for each dim, because there is a risk of
        # "counting things twice".  E.g. for a matrix with 2 dims, if we do the SVD
        # M = U S V^T and consider the covariance on the left and the right,
        # S will be reflected in both covariances, so it doesn't make sense to
        # correct for S twice.  Not all parameter matrices have exactly 2 dims,
        # and we're also dealing with accumulated parameter stats, which makes
        # things not quite so simple, so we don't want to just take the sqrt of
        # S.

        # cur_scales[dim] will be a tensor with a shape like (batch_size, 1, size, 1, 1),
        # containing the scales on the learning-rate matrix for this dimension.
        # We apply these scales to the parameter matrix before estimating the
        # cur_scales for the other dims.
        cur_scales = [None] * ndim

        debug = (random.random() < 0.1)
        for i in range(4):  # for 4 iterations (this is quite arbitrary)
            for dim in range(1, ndim):
                size = p.shape[dim]
                if not f"Q_{dim}" in state:
                    assert size == 1 or size == numel, size
                    continue
                if cur_scales[dim] is not None:
                    # Correct for the fact that we have already normalized this dim in cur_p,
                    # i.e. remove the scaling for *this* dim, while keeping the scaling for
                    # all other dims.
                    cur_p *= cur_scales[dim]  # cur_scales shape: (batch_size, 1, size, 1, 1)
                # rms shape: (batch_size, 1, size, 1, 1)
                rms = _mean(cur_p**2, exclude_dims=[0, dim], keepdim=True).sqrt()
                rank = numel // size
                # TODO: consider more smoothing here???
                cur_scales[dim] = rms
                cur_p /= rms

                if debug:
                    def _summarize(rms):
                        rms = rms[0]  # get rid of the batch dim by selecting one example
                        rms = rms.flatten()
                        return rms[::10]  # subsample one item in every ten
                    logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}")

        # Apply the scales in `cur_scales` to Q for each dim; this reflects the
        # parameter rms values, in the parameter-diagonalized space, that we have
        # estimated in the loop above.
        #
        # We normalize the scales in such a way that the Frobenius norm
        # after projecting (grad / denom) with Q should be unchanged, i.e. the
        # same as that of (grad / denom), which is equivalent to having rms=1.0 due
        # to how denom is constructed.  This simplifies the normalization of the overall
        # scale of the parameter change: we just have to multiply by the learning
        # rate and param_rms.
        for dim in range(1, ndim):
            if cur_scales[dim] is not None:
                size = p.shape[dim]
                Q = state[f"Q_{dim}"]
                (batch_size, num_blocks, block_size, block_size) = Q.shape

                scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)

                # Geometrically interpolate scale with P_proj[dim].sqrt()
                P_proj_weight = 0.5
                scale = ((scale ** (1 - P_proj_weight)) *
                         (P_proj[dim].reshape(batch_size, num_blocks, block_size, 1) ** (P_proj_weight * 0.5)))

                # The following normalization step will ensure the Frobenius
                # norm is unchanged from applying this scale: at least,
                # assuming "grad / denom" gives uncorrelated outputs, so that
                # they will have equal variances after projecting to the space
                # where the parameter var is diagonalized... this is *roughly*
                # true because the gradients at the point where we compute "grad
                # / denom" should be decorrelated, at least considering
                # individual tensor dims.
                scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()

                if True:
                    # debug info.
                    step = state["step"]
                    scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
                    logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")

                # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate];
                # we want to multiply on the diagonalized co-ordinate.
                state[f"Q_{dim}"] *= scale
        state["last_param_scale_update"] = state["step"]

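# ------------------------------------------------------------------------------
# Standalone sketch (small hypothetical tensors, plain torch) of the "counting things
# twice" issue discussed in the comments above, for the 2-dimensional case: if
# M = U S V^T, the row covariance M M^T and the column covariance M^T M share the *same*
# singular values S**2, so naively correcting both covariances would correct for S twice.
import torch

torch.manual_seed(0)
M = torch.randn(6, 4)
U, S, Vh = torch.linalg.svd(M, full_matrices=False)

row_cov_eigs = torch.linalg.eigvalsh(M @ M.t()).flip(0)[:4]   # largest 4, descending
col_cov_eigs = torch.linalg.eigvalsh(M.t() @ M).flip(0)       # all 4, descending

print(S**2)            # squared singular values
print(row_cov_eigs)    # same values, up to numerical error
print(col_cov_eigs)    # same values again: S shows up in *both* covariances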
    def _compute_bases(self,
                       group: dict,
                       p_shape: torch.Size,
                       state: dict) -> List[Optional[Tensor]]:
                       state: dict):
        """
        For each tensor dimension that we are preconditioning, this function sets
        state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
@@ -906,6 +661,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # dimensions with zero grad separate from those with nonzero grad.
            G_prime = _diag(torch.matmul(U_g.transpose(2,3), torch.matmul(grad_cov, U_g)))

            G_prime_noeps = G_prime.clone()
            # Use the form of the diagonalized gradient matrix that we get after
            # we add the Adam-type smoothing with epsilon.
            G_prime += (_mean(G_prime, exclude_dims=[0], keepdim=True) * (denom_rel_eps * denom_rel_eps) +
@@ -926,109 +682,32 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            #C = _fake_cholesky(P_prime)
            C = P_prime.cholesky()

            # OK, P == U C C^T U^T.
            # A matrix that takes normally distributed data to P would
            # be U_g C, because C I C^T = C C^T = P.  We can actually use *any* matrix
            # that takes normally distributed data to P, so we can use
            # U_g C U for any orthogonal U, since U_g C U I U^T C^T U_g^T == P.
            # So there is no harm in choosing a matrix U that diagonalizes the
            # projected grad_cov.  grad_cov gets projected by

            # CGC = (C^T G' C); it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
            # if G were formatted as a diagonal matrix, but G is just the diagonal.
            CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
            # Make sure it's exactly symmetric; we want the SVD to be exact w.r.t.
            # dimensions with zero grad and zero parameters.
            CGC = 0.5 * (CGC + CGC.transpose(2, 3))
            # This projects P; its transpose projects the gradient.
            UC = torch.matmul(U_g, C)

            U, S, _ = _svd(CGC)  # Need the SVD to compute CGC^{0.5}
            # Instead of projecting grad_cov, we can just use its diagonal, forget the
            # U_g part of the transform, and project with C.
            grad_cov_proj = torch.matmul(C.transpose(2, 3) * G_prime_noeps.unsqueeze(-1), C)

            # (eqn:4) says: Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1}.
            # Since (C^T G' C)^0.5 = U S^{0.5} U^T,
            # we have Z'^{-1} = X S^{0.5} X^T, where X = C^{-T} U.
            X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
            Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
            # Make sure it's exactly symmetric; we want the SVD to be exact w.r.t.
            # dimensions with zero grad and zero parameters.
            Z_prime_inv = 0.5 * (Z_prime_inv + Z_prime_inv.transpose(2, 3))
            # OK, grad_cov is diagonalized by U^T C^T U_g^T.  So the projection that we
            # apply to the param cov is U_g C U.
            U, S, _ = _svd(grad_cov_proj)

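# ------------------------------------------------------------------------------
# Standalone numeric check (small hypothetical matrices, plain torch) of the identity
# used above: with P' = C C^T (Cholesky), G' a positive diagonal, and
#     Z'^{-1} = C^{-T} (C^T G' C)^{0.5} C^{-1} = X S^{0.5} X^T,  X = C^{-T} U,
# where U diag(S) U^T is the eigendecomposition of C^T G' C, we should recover (eqn:2):
#     Z'^{-1} P' Z'^{-1} == diag(G').
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n)
P_prime = A @ A.t() + n * torch.eye(n)         # an SPD "parameter covariance"
G_prime = torch.rand(n) + 0.1                  # a positive diagonal "gradient covariance"

C = torch.linalg.cholesky(P_prime)             # P' == C C^T, C lower-triangular
CGC = C.t() @ torch.diag(G_prime) @ C
CGC = 0.5 * (CGC + CGC.t())                    # keep it exactly symmetric
S, U = torch.linalg.eigh(CGC)                  # CGC == U diag(S) U^T

X = torch.linalg.solve(C.t(), U)               # X == C^{-T} U
Z_prime_inv = X @ torch.diag(S.sqrt()) @ X.t()

print(torch.allclose(Z_prime_inv @ P_prime @ Z_prime_inv,
                     torch.diag(G_prime), atol=1e-4))   # True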
            if True:
                def _check_similar(x, y, name):
                    ratio = (y - x).abs().sum() / (x.abs().sum() + 1.0e-20)
                    if not (ratio < 0.0001):
                        logging.warning(f"Check {name} failed, ratio={ratio.item()}, {x} vs. {y}")
            # proj is indexed [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate],
            # so we need to transpose to get Q_{dim}.
            proj = torch.matmul(UC, U)

                def _check_symmetric(x, x_name):
                    diff = x - x.transpose(-2, -1)
                    ratio = diff.abs().sum() / x.abs().sum()
                    if not (ratio < 0.0001):
                        logging.warning(f"{x_name} is not symmetric: ratio={ratio.item()}")
            Q[:] = proj.transpose(2, 3)

                _check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
                _check_similar(torch.matmul(C.transpose(2, 3), X), U, "CTX")
                _check_symmetric(Z_prime_inv, "Z_prime_inv")
                _check_symmetric(P_prime, "P_prime")
                # A check.
                # Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime   (eqn:2),
                # or alternatively, with the inverse,
                #    G_prime = Z_prime_inv P_prime Z_prime_inv
                G_prime_check = _diag(torch.matmul(Z_prime_inv, torch.matmul(P_prime, Z_prime_inv)))
                _check_similar(G_prime, G_prime_check, "G_prime")

            Z_inv = torch.matmul(U_g, torch.matmul(Z_prime_inv, U_g.transpose(2, 3)))
            Z_inv = 0.5 * (Z_inv + Z_inv.transpose(2, 3))  # make sure exactly symmetric

            Z_inv_diag = _diag(Z_inv)  # aliased with Z_inv
            # This is smoothing Z relative to its own diagonal.  This is Z_inv,
            # so by applying a minimum here, we are applying a maximum on the
            # eigs of Z after normalizing so the diagonal is 1.
            Z_inv_diag *= (1. + group["cov_min"][4])

            # We really want the SVD on Z, which will be used for the learning-rate matrix
            # Q, but Z_prime is better, numerically, to work on because it's closer to
            # being diagonalized.
            U_z, S_z_inv, _ = _svd(Z_inv)

            if True:
                skip = 10 if S.shape[-1] > 40 else 1
                logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z_inv are: {S_z_inv[0,0,::skip]}")

            # state[f"Q_{dim}"] is indexed [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate],
            # so we need to transpose U_z, as U_z is indexed
            # [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate].
            Q[:] = U_z.transpose(2, 3)

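# ------------------------------------------------------------------------------
# Sketch of the "aliased with Z_inv" trick above (assuming `_diag` returns a view of the
# diagonal, as torch.diagonal does): multiplying the view in place scales the diagonal of
# the original matrix, with no copy.  Small hypothetical tensors only.
import torch

Z_inv = torch.eye(3) + 0.01 * torch.randn(3, 3)
Z_inv = 0.5 * (Z_inv + Z_inv.t())
Z_inv_diag = torch.diagonal(Z_inv, dim1=-2, dim2=-1)   # a view, not a copy
Z_inv_diag *= 1.2                                      # modifies Z_inv's diagonal in place
print(torch.diagonal(Z_inv))                           # roughly 1.2 on the diagonal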
            # Work out the diagonal P_proj_diag of the projected smoothed parameter covariance P_proj, which is P
            # projected in the basis U_z.  This will be used to get the parameter scales for the
            # bases Q_{dim}.  Now,
            #          P = U_g P' U_g^T,
            #  and P_proj = U_z^T P U_z,
            #   so P_proj = (U_z^T U_g) P' (U_z^T U_g)^T.
            U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
            # this_P_proj_diag shape: (batch_size, num_blocks, block_size)
            this_P_proj_diag = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
            P_proj[dim] = this_P_proj_diag.clone().reshape(batch_size, size)

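# ------------------------------------------------------------------------------
# Standalone check (small hypothetical matrices) of the algebra in the comment above:
# if P = U_g P' U_g^T, then U_z^T P U_z == (U_z^T U_g) P' (U_z^T U_g)^T, so the projected
# covariance can be formed from P' and the product U_z^T U_g alone.
import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n)
P_prime = A @ A.t() + torch.eye(n)
U_g = torch.linalg.qr(torch.randn(n, n))[0]    # orthogonal
U_z = torch.linalg.qr(torch.randn(n, n))[0]    # orthogonal

P = U_g @ P_prime @ U_g.t()
lhs = U_z.t() @ P @ U_z
U_prod = U_z.t() @ U_g
rhs = U_prod @ P_prime @ U_prod.t()
print(torch.allclose(lhs, rhs, atol=1e-5))     # True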
            simple_update = True
            if simple_update:
                # Normalize the scales in a way that preserves the Frobenius norm of the
                # projected parameter deltas.
                P_rms = this_P_proj_diag / _mean(this_P_proj_diag, exclude_dims=[0], keepdim=True)
                scale = P_rms.unsqueeze(-1).sqrt()
                Q *= scale
                logging.info(f"Q rms = {(Q**2).mean().sqrt()} abs-rms = {Q.abs().mean()}")
                # No iterative stuff: just use sqrt(P_proj) as the scale on Q.  If this is False, we need to
                # call self._update_param_scales(...) from the calling function.
                if True:
                    # debug output
                    step = state["step"]
                    scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
                    logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")
            if True:
                # debug output
                this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed,
                                                                                 U_prod.transpose(2, 3))))
                G_proj_unsmoothed = _diag(torch.matmul(U_prod * G_prime.unsqueeze(-2), U_prod.transpose(2, 3)))
                skip = 10 if block_size > 40 else 1
                logging.info(f"dim={dim}, diag of P_proj is: {this_P_proj_diag[0,0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,0,::skip]}, diag of unsmoothed G_proj is {G_proj_unsmoothed[0,0,::skip]}")

        # P_proj won't be needed if simple_update == True.
        return P_proj
            continue

    def _smooth_param_cov(self,
@@ -1199,84 +878,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        return X

    def _diagonalize_grad_cov(self,
                              group: dict,
                              p: Tensor,
                              state: dict) -> None:
        """
        Called from _update_lrs(), this function diagonalizes the gradient covariance
        state[f"grad_cov_{dim}"] by modifying the projections state[f"Q_{dim}"]: specifically,
        by left-multiplying the projections by an orthogonal matrix.
        """
        batch_size = p.shape[0]
        numel = p.numel() // batch_size
        for dim in range(1, p.ndim):  # dim 0 is the batch dim
            size = p.shape[dim]
            try:
                # Q and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
                Q = state[f"Q_{dim}"]
                grad_cov = state[f"grad_cov_{dim}"]
            except KeyError:
                assert size == 1 or size == numel
                continue

            # Suppose the actual parameter matrix p is M, of shape (-1, size), where
            # the -1 represents all other tensor dims treated as a batch dimension.
            # M_grad is the same shape as M.  We could write a pseudo-loss as
            #    loss = tr(M_grad^T M).
            # Because we can decompose M as M == N Q, we can write:
            #    loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
            # so we can write this as tr(N_grad^T N),
            # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
            # Now,
            #    grad_cov == M_grad^T M_grad,
            # decaying-averaged over minibatches; this is of shape (size, size).
            # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
            # (which is what we want to diagonalize) as:
            #    N_grad_cov = N_grad^T N_grad
            #               = Q M_grad^T M_grad Q^T
            #               = Q grad_cov Q^T
            # (note: this makes sense because the 1st index of Q is the diagonalized
            # index).

            def dispersion(v):
                # a ratio that will be larger if the eigenvalues are more dispersed.
                return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()

            # N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
            N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
            N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3)  # ensure symmetric
            U, S, V = _svd(N_grad_cov)
            if random.random() < 0.01:
                logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
                             f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")

            # N_grad_cov is SPD, so
            #    N_grad_cov = U S U^T.

            # Now, we can diagonalize N_grad_cov with:
            #    U^T N_grad_cov U == S.
            # N_grad_cov is a sum of N_grad^T N_grad terms.
            # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
            # The linearized pseudo-loss can be written as tr(N_grad^T N).
            # This can be written as tr(U U^T N_grad^T N), since U U^T == I,
            # which we can rearrange (using the cyclic property of the trace)
            # as tr(U^T N_grad^T N U).  This can be interpreted
            # as tr(hat_N_grad^T hat_N), where:
            #    hat_N_grad = N_grad U
            #    hat_N = N U
            # (hat_N means \hat{N}, or N with a hat on it).
            # So if we interpret hat_N = N U, the gradient covariance w.r.t.
            # hat_N will be diagonalized.  We also modify Q to hat_Q when
            # we modify hat_N, to keep the product M unchanged:
            #    M = N Q = N U U^T Q = hat_N hat_Q.
            # This can be done by setting
            #    hat_Q := U^T Q     (eq.10)
            #
            # This is the only thing we have to do, as N is implicit
            # and not materialized at this point.
            Q[:] = torch.matmul(U.transpose(2, 3), Q)

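# ------------------------------------------------------------------------------
# Standalone sketch (small hypothetical matrices) of the update just above: with
# N_grad_cov = Q grad_cov Q^T and its eigendecomposition N_grad_cov = U diag(S) U^T,
# replacing Q by U^T Q makes the projected gradient covariance diagonal, while M = N Q
# is unchanged because N is simultaneously (implicitly) replaced by N U.
import torch

torch.manual_seed(0)
size = 5
A = torch.randn(size, size)
grad_cov = A @ A.t()                              # stands in for accumulated M_grad^T M_grad
Q = torch.linalg.qr(torch.randn(size, size))[0]   # current learning-rate basis
N = torch.randn(7, size)                          # the implicit factor, M == N Q

N_grad_cov = Q @ grad_cov @ Q.t()
S, U = torch.linalg.eigh(N_grad_cov)              # N_grad_cov == U diag(S) U^T

Q_hat = U.t() @ Q                                 # the update: Q <- U^T Q
N_hat = N @ U                                     # implicit: N <- N U

new_cov = Q_hat @ grad_cov @ Q_hat.t()
off_diag = new_cov - torch.diag(torch.diagonal(new_cov))
print(off_diag.abs().max())                                # ~0: covariance is now diagonal
print(torch.allclose(N @ Q, N_hat @ Q_hat, atol=1e-5))     # True: M is unchanged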
    def _update_grad_cov(self,
                         group: dict,