Big simplification to update rule

Daniel Povey 2022-07-30 00:21:12 -07:00
parent a80a8abf0c
commit 105d49d31b


@@ -427,16 +427,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
else:
if self._is_lr_update_step(group, state):
self._update_param_cov(group, p, state)
P_proj = self._compute_bases(group, p.shape, state)
# Only update the parameter-dependent part of the learning
# rate matrices at most every other time we reach here, and
# less frequently than that later in training.
#self._update_param_scales(group, p, state, P_proj)
#self._update_param_scales_simple(group, p, state, P_proj)
# We won't be doing this any more.
#self._diagonalize_grad_cov(group, p, state)
self._compute_bases(group, p.shape, state)
self._zero_exp_avg_sq(state)
if step % grad_cov_period == 0:
self._update_grad_cov(group, p, state)
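(The comment above describes the intended schedule: the parameter-dependent parts of the learning-rate matrices are refreshed at most every other time this branch is reached, and progressively less often later in training. The real schedule lives in _is_lr_update_step(), of which only the tail is visible in this diff; the snippet below is a purely hypothetical stand-in for such a growing update period, not code from this commit.)

def _hypothetical_update_period(step: int) -> int:
    # Hypothetical schedule, for illustration only: the gap between
    # learning-rate-matrix updates grows with the step count, so the expensive
    # basis recomputation happens less and less often late in training.
    return min(200, 2 * (1 + step // 1000))

# e.g. step=0 -> every 2 steps, step=10_000 -> every 22 steps, step>=100_000 -> every 200 steps.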
@@ -588,247 +579,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
return (step >= zero_step + cur_update_period)
def _update_param_scales_simple(self,
group: dict,
p: Tensor,
state: dict,
P_proj: List[Optional[Tensor]]) -> None:
numel = p.numel() // p.shape[0]  # per-batch-element numel, referenced by the assert below
for dim in range(1, p.ndim):
size = p.shape[dim]
try:
Q = state[f"Q_{dim}"]
except KeyError:
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel
(batch_size, num_blocks, block_size, block_size) = Q.shape
this_P_proj = P_proj[dim].reshape(batch_size, num_blocks, block_size, 1)
# The following normalization step will ensure the Frobenius
# norm is unchanged, from applying this scale: at least,
# assuming "grad / denom" gives uncorrelated outputs so that
# they will have equal variances after projecting to the space
# where the parameter var is diagonalized... this is *roughly*
# true because the gradients at the point where we compute "grad
# / denom" should be decorrelated at least considering
# individual tensor dims
this_P_proj /= _mean(this_P_proj, exclude_dims=[0], keepdim=True)
if True:
# debug info.
scale = this_P_proj.sqrt()
step = state["step"]
scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")
Q *= this_P_proj.sqrt()
logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}")
def _update_param_scales(self,
group: dict,
p: Tensor,
state: dict,
P_proj: List[Optional[Tensor]]) -> None:
"""
Modifies the scales on the rows of the learning-rate matrices Q for each dim of this tensor,
to take into account the estimated parameter covariance.
Args:
group: dict to look up configuration values
p: parameter matrix that we are updating. The learning rate matrices
are actually factors of p, so p itself will change when we change
them.
state: state dict for the current parameter
P_proj: a list indexed by dim, containing tensors of shape (batch_size, size)
where size == p.shape[dim], which represent the parameter covariance
projected by state[f"Q_{dim}"] (i.e. by the value of this variable
at entry to this function)
"""
ndim = p.ndim
batch_size = p.shape[0]
numel = p.numel() // batch_size
if numel in p[0].shape:
return # Nothing to do for this parameter matrix. E.g. a bias or a scalar.
scale_arr = [None] * ndim
# the small random part is to ensure there are no exact zeros, e.g.
# if we initialized some parameters to zero.
#
# we are going to make "cur_p" a corrected version of the parameter
# covariance that is projected with the orthogonal projections U, and that,
# on the axes given by these projections, has covariance proportional
# to that of state[f"param_cov_{dim}"] on those same axes. So it's
# *slightly* whitened, just to match the stats, but not completely
# whitened.
eps = 1.0e-20
cur_p = p + eps * torch.randn_like(p)
for dim in range(1, ndim):
size = p.shape[dim]
try:
Q = state[f"Q_{dim}"]
except KeyError:
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel
(batch_size, num_blocks, block_size, block_size) = Q.shape
M = cur_p.transpose(dim, -1)
# if p were of shape (batch_size, x, size, y, z),
# after the next line M would be of shape
# (batch_size, x, z, y, num_blocks, block_size)
M = M.reshape(*M.shape[:-1], num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, z, y, block_size)
while Q.ndim < M.ndim:
Q = Q.unsqueeze(2)
# Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
# with indexes [batch_index, block_index, 1, 1, diagonalized_coordinate, canonical_coordinate],
# so we need to transpose Q as we convert M to the diagonalized co-ordinate.
M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, z, y, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, z, y, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, z, y, size)
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
# cur_param_var is a diagonal parameter variance over dimension `dim`,
# of the current "slightly-whitened" parameter; it
# will have shape [batch_size, 1, size, 1].
cur_param_var = _mean(cur_p**2,
exclude_dims=[0,dim],
keepdim=True) # (batch_size, 1, size, 1, 1) if dim==2
smoothed_param_var = P_proj[dim] # (batch_size, size)
S = smoothed_param_var.reshape(cur_param_var.shape) # (batch_size, 1, size, 1, 1) if dim==2
# OK, cur_param_var would have the values as S if the variance stats
# P_proj[dim] were accumulated from this exact parameter matrix and
# not smoothed, but actually they contain older versions of the
# parameter covariance and they have been smoothed, so they will, in
# general, be less extreme ("whiter spectrum"). We scale p so that
# it matches the estimated variance P_proj[dim]; the idea is to ensure it doesn't
# have too-extreme eigenvalues (where the stats permit).
# Actually we could just use P_proj[dim].sqrt(), suitably scaled,
# as the scales on the rows of Q (see _update_param_scales_simple() which does
# exactly this), but there is a problem of "counting things twice"
# which is easiest to understand for a 2-dimensional tensor, where the
# singular values show up identically in the covariance over either axis.
# The estimation procedure in this function avoids the "counting things twice"
# problem, at the expense of quite a bit of extra complexity.
scale = (S / cur_param_var.clamp(min=eps)).sqrt()
if True:
S_tmp = S.reshape(batch_size, size)
cur_tmp = cur_param_var.reshape(batch_size, size)
scale_tmp = scale.reshape(batch_size, size)
skip = 10 if size > 40 else 1
logging.info(f"dim={dim}/{ndim}, cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
if random.random() < 0.01:
skip = 10 if size > 20 else 1
logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
# scale shape: (batch_size, 1, size, 1, 1)
cur_p *= scale
# OK, at this point we have a matrix cur_p that is (somewhat)
# diagonalized by orthogonal matrices in each non-trivial dim, that is
# also multiplied by scalars to match the accumulated covariance stats,
# i.e. "slightly whitened" (in general this will make a modest
# difference, making the eigenvalue distribution a bit flatter). Now we
# will work out the "scaling" part of the learning-rate matrices Q. We
# can't do this independently for each dim, because there is a risk of
# "counting things twice". E.g. for a matrix with 2 dims, if we do SVD
# M = U S V^T, if we consider the covariance on the left and the right,
# S will be reflected in both covariances, so it doesn't make sense to
# correct for S twice. Not all parameter matrices have exactly 2 dims,
# and also we're dealing with accumulated parameter stats which makes
# things not quite so simple, so we don't want to just take the sqrt of
# S.
# cur_scales[dim] will be a 1-d tensor with shapes like (batch_size, 1, 1, size, 1),
# containing the scales on the learning-rate matrix for this dimension.
# we apply these scales to the parameter matrix before estimating the
# cur_scales for the other dims.
cur_scales = [None] * ndim
debug = (random.random() < 0.1)
for i in range(4): # for 4 iterations (this is quite arbitrary)
for dim in range(1, ndim):
size = p.shape[dim]
if not f"Q_{dim}" in state:
assert size == 1 or size == numel, size
continue
if cur_scales[dim] is not None:
# correct for the fact that we have already normalized this dim in cur_p,
# i.e. remove the scaling for *this* dim, while keeping the scaling for
# all other dims.
cur_p *= cur_scales[dim] # cur_scales shape: (batch_size, 1, size, 1, 1)
# rms shape: (batch_size, 1, size, 1, 1)
rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
rank = numel // size
# TODO: consider more smoothing here???
cur_scales[dim] = rms
cur_p /= rms
if debug:
def _summarize(rms):
rms = rms[0] # get rid of batch dim by selecting one example
rms = rms.flatten()
return rms[::10] # subset one every ten items
logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}")
# Apply the scales in `cur_scales` to Q for each dim; this reflects the
# parameter rms values in the parameter-diagonalized space, that we have
# estimated in the loop above.
#
# We normalize the scales in such a way that the Frobenius norm
# after projecting (grad / denom) with Q should be unchanged, i.e. the
# same as (grad / denom), which is equivalent to having rms=1.0 due
# to how denom is constructed. This simplifies the normalization of the overall
# scale of the parameter change: we just have to multiply by the learning
# rate and param_rms.
for dim in range(1, ndim):
if cur_scales[dim] is not None:
size = p.shape[dim]
Q = state[f"Q_{dim}"]
(batch_size, num_blocks, block_size, block_size) = Q.shape
scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
# Geometrically interpolate scale with P_proj[dim].sqrt()
P_proj_weight = 0.5
scale = ((scale ** (1-P_proj_weight)) *
(P_proj[dim].reshape(batch_size, num_blocks, block_size, 1) ** (P_proj_weight * 0.5)))
# The following normalization step will ensure the Frobenius
# norm is unchanged, from applying this scale: at least,
# assuming "grad / denom" gives uncorrelated outputs so that
# they will have equal variances after projecting to the space
# where the parameter var is diagonalized... this is *roughly*
# true because the gradients at the point where we compute "grad
# / denom" should be decorrelated at least considering
# individual tensor dims
scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
if True:
# debug info.
step = state["step"]
scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")
# Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
# want to multiply on the diagonalized co-ordinate.
state[f"Q_{dim}"] *= scale
state["last_param_scale_update"] = state["step"]
def _compute_bases(self,
group: dict,
p_shape: torch.Size,
state: dict) -> List[Optional[Tensor]]:
state: dict):
"""
For each tensor dimension that we are preconditioning, this function sets
state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
@@ -906,6 +661,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# dimensions with zero grad separate from those with nonzero grad.
G_prime = _diag(torch.matmul(U_g.transpose(2,3), torch.matmul(grad_cov, U_g)))
G_prime_noeps = G_prime.clone()
# Use the form of the diagonalized gradient matrix that we get after
# we add the Adam-type smoothing with epsilon.
G_prime += (_mean(G_prime, exclude_dims=[0], keepdim=True) *(denom_rel_eps * denom_rel_eps) +
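(The smoothing expression above is cut off at the hunk boundary, so only the mean-relative term is visible. That term acts like an Adam-style relative epsilon: adding mean(G') * denom_rel_eps**2 to every entry lifts near-zero diagonalized gradient variances to a fixed fraction of the average one. A tiny illustration of just that visible term, with a hypothetical denom_rel_eps value:)

import torch

denom_rel_eps = 0.025                              # hypothetical value
G_prime = torch.tensor([0.0, 1e-6, 0.1, 1.0, 10.0])
G_smoothed = G_prime + G_prime.mean() * denom_rel_eps ** 2
print(G_smoothed)   # near-zero entries are floored at about mean(G_prime) * denom_rel_eps**2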
@@ -926,109 +682,32 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
#C = _fake_cholesky(P_prime)
C = P_prime.cholesky()
# OK, P == U C C^T U^T.
# A matrix that takes normally distributed data to P would
# be U_g C, because (U_g C) I (U_g C)^T = U_g C C^T U_g^T = P. We can actually use *any* matrix
# that takes normally distributed data to P, so we can use
# U_g C U for any orthogonal U, since U_g C U I U^T C^T U_g^T == P.
# So there is no harm in choosing a matrix U that diagonalizes the
# projected grad_cov. grad_cov gets projected by
# CGC = (C^T G' C), it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
# if G were formatted as a diagonal matrix, but G is just the diagonal.
CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
# make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
# dimensions with zero grad and zero parameters.
CGC = 0.5 * (CGC + CGC.transpose(2, 3))
# this projects P; its transpose projects the gradient.
UC = torch.matmul(U_g, C)
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{0.5}
# instead of projecting grad_cov, we can just use its diagonal, forget the
# U_g part of the transform, and project with C.
grad_cov_proj = torch.matmul(C.transpose(2, 3) * G_prime_noeps.unsqueeze(-1), C)
# (eqn:4) says: Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1}
# Since (C^T G' C)^0.5 = U S^{0.5} U^T,
# we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
# make sure it's exactly symmetric, want to make sure SVD is exact w.r.t.
# dimensions with zero grad and zero parameters.
Z_prime_inv = 0.5 * (Z_prime_inv + Z_prime_inv.transpose(2, 3))
# OK, grad_cov is diagonalized by U^T C^T U_g^T. So the projection that we
# apply to the param cov is U_g C U
U, S, _ = _svd(grad_cov_proj)
if True:
def _check_similar(x, y, name):
ratio = (y-x).abs().sum() / (x.abs().sum() + 1.0e-20)
if not (ratio < 0.0001):
logging.warning(f"Check {name} failed, ratio={ratio.item()}, {x} vs. {y}")
# proj is indexed [batch_idx,block_idx,canonical_coordinate,diagonalized_coordinate],
# so we need to transpose to get Q_{dim}.
proj = torch.matmul(UC, U)
def _check_symmetric(x, x_name):
diff = x - x.transpose(-2, -1)
ratio = diff.abs().sum() / x.abs().sum()
if not (ratio < 0.0001):
logging.warning(f"{x_name} is not symmetric: ratio={ratio.item()}")
Q[:] = proj.transpose(2, 3)
_check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
_check_similar(torch.matmul(C.transpose(2, 3), X), U, "CTX")
_check_symmetric(Z_prime_inv, "Z_prime_inv")
_check_symmetric(P_prime, "P_prime")
# A check.
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2),
# or alternatively with the inverse,
# G_prime = Z_prime_inv P_prime Z_prime_inv
G_prime_check = _diag(torch.matmul(Z_prime_inv, torch.matmul(P_prime, Z_prime_inv)))
_check_similar(G_prime, G_prime_check, "G_prime")
Z_inv = torch.matmul(U_g, torch.matmul(Z_prime_inv, U_g.transpose(2, 3)))
Z_inv = 0.5 * (Z_inv + Z_inv.transpose(2, 3)) # make sure exactly symmetric
Z_inv_diag = _diag(Z_inv) # aliased with Z_inv
# this is smoothing Z relative to its own diagonal. This is z_inv,
# so by applying a minimum here, we are applying a maximum of the
# eigs of Z after normalizing so the diagonal is 1.
Z_inv_diag *= (1. + group["cov_min"][4])
# We really want the SVD on Z, which will be used for the learning-rate matrix
# Q, but Z_prime is better, numerically, to work on because it's closer to
# being diagonalized.
U_z, S_z_inv, _ = _svd(Z_inv)
if True:
skip = 10 if S.shape[-1] > 40 else 1
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z_inv are: {S_z_inv[0,0,::skip]}")
# state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# so we need to transpose U_z as U_z is indexed
# [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate]
Q[:] = U_z.transpose(2, 3)
# Work out the diagonal P_proj_diag of the projected smoothed parameter covariance P_proj, which is P
# projected in the basis U_z. This will be used to get the parameter scales for the
# bases Q_{dim}. Now,
# P = U_g P' U_g^T,
# and P_proj = U_z^T P U_z,
# so P_proj = (U_z^T U_g) P' (U_z^T U_g)^T
U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
# this_P_proj shape: (batch_size, num_blocks, block_size)
this_P_proj_diag = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
P_proj[dim] = this_P_proj_diag.clone().reshape(batch_size, size)
simple_update = True
if simple_update:
# normalize the scales in a way that preserves the Frobenius norm of the
# projected parameter deltas
P_rms = this_P_proj_diag / _mean(this_P_proj_diag, exclude_dims=[0], keepdim=True)
scale = P_rms.unsqueeze(-1).sqrt()
Q *= scale
logging.info(f"Q rms = {(Q**2).mean().sqrt()} abs-rms = {Q.abs().mean()}")
# no iterative stuff, just use sqrt(P_proj) as scale on Q. If this is False, we need to
# call self._update_param_scales(...) from the calling function.
if True:
# debug output
step = state["step"]
scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")
if True:
# debug output
this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed,
U_prod.transpose(2, 3))))
G_proj_unsmoothed = _diag(torch.matmul(U_prod * G_prime.unsqueeze(-2), U_prod.transpose(2, 3)))
skip = 10 if block_size > 40 else 1
logging.info(f"dim={dim}, diag of P_proj is: {this_P_proj_diag[0,0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,0,::skip]}, diag of unsmoothed G_proj is {G_proj_unsmoothed[0,0,::skip]}")
# P_proj won't be needed if simple_update == True.
return P_proj
continue
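(The identities (eqn:2) and (eqn:4) used above can be checked numerically in isolation. The sketch below, using a random SPD matrix, a single block and no batch dim, builds P' = C C^T via Cholesky, diagonalizes C^T G' C = U S U^T, forms X = C^{-T} U and Z'^{-1} = X S^{0.5} X^T, and confirms that Z'^{-1} P' Z'^{-1} recovers diag(G'), which is what the G_prime_check comparison in the code verifies.)

import torch

torch.manual_seed(0)
n = 6
A = torch.randn(n, n)
P_prime = A @ A.T + 0.1 * torch.eye(n)          # SPD stand-in for P'
G_prime = torch.rand(n) + 0.1                   # stand-in for the diagonal G'

C = torch.linalg.cholesky(P_prime)              # P' == C C^T
CGC = C.T @ torch.diag(G_prime) @ C             # C^T G' C
U, S, _ = torch.linalg.svd(0.5 * (CGC + CGC.T))
X = torch.linalg.solve_triangular(C.T, U, upper=True)   # X = C^{-T} U
Z_prime_inv = X @ torch.diag(S.sqrt()) @ X.T    # Z'^{-1} = X S^0.5 X^T

# Z'^{-1} P' Z'^{-1} == diag(G'), up to float32 round-off.
print(torch.allclose(Z_prime_inv @ P_prime @ Z_prime_inv,
                     torch.diag(G_prime), atol=1e-3))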
def _smooth_param_cov(self,
@@ -1199,84 +878,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
return X
def _diagonalize_grad_cov(self,
group: dict,
p: Tensor,
state: dict) -> None:
"""
Called from _update_lrs(), this function diagonalizes the gradient covariance
state[f"grad_cov_{dim}"] by modifying the projections state[f"Q_{dim}"]: specifically,
by left-multiplying the projections by an orthogonal matrix.
"""
batch_size = p.shape[0]
numel = p.numel() // batch_size
for dim in range(1, p.ndim): # dim 0 is batch dim
size = p.shape[dim]
try:
# Q and grad_cov shape: (batch_size, num_blocks, block_size, block_size)
Q = state[f"Q_{dim}"]
grad_cov = state[f"grad_cov_{dim}"]
except KeyError:
assert size == 1 or size == numel
continue
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
# the -1 represents all other tensor dims treated as a batch dimension.
# M_grad is the same shape as M. We could write a pseudo-loss as
# loss = tr(M_grad^T M)
# Because we can decompose M as M == N Q, we can write:
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
# so we can write this as tr(N_grad^T N),
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize), as::
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that is larger when the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
# N_grad_cov shape: (batch_size, num_blocks, block_size, block_size)
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.transpose(2, 3)))
N_grad_cov = N_grad_cov + N_grad_cov.transpose(2, 3) # ensure symmetric
U, S, V = _svd(N_grad_cov)
if random.random() < 0.01:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(_diag(N_grad_cov))} to {dispersion(S)}")
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N).
# This can be written as tr(U U^T N_grad^T N), since U U^T == I,
#
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
# as tr(hat_N_grad^T hat_N), where:
# hat_N_grad = N_grad U
# hat_N = N U
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.transpose(2, 3), Q)
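(The update hat_Q := U^T Q derived above (eq.10) can be checked on a toy example: with random stand-ins, a single block and no batch dim, the gradient covariance expressed in the new basis comes out diagonal while the product M = N Q is unchanged.)

import torch

torch.manual_seed(0)
n = 5
Q = torch.linalg.qr(torch.randn(n, n)).Q        # stand-in orthogonal basis Q
N = torch.randn(7, n)                           # implicit factor, M = N Q
A = torch.randn(n, n)
grad_cov = A @ A.T                              # stand-in for M_grad^T M_grad

N_grad_cov = Q @ grad_cov @ Q.T                 # gradient covariance w.r.t. N
U, S, _ = torch.linalg.svd(N_grad_cov)

hat_Q = U.T @ Q                                 # eq.10
hat_N = N @ U

# The covariance in the new basis is diagonal, and M = N Q is unchanged.
print(torch.allclose(hat_Q @ grad_cov @ hat_Q.T, torch.diag(S), atol=1e-4))
print(torch.allclose(hat_N @ hat_Q, N @ Q, atol=1e-5))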
def _update_grad_cov(self,
group: dict,