this version not working great

Daniel Povey 2022-07-30 21:14:03 -07:00
parent 790e8c4ba9
commit cb67540cdc


@@ -655,55 +655,51 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# is the same as param_cov and grad_cov.
#
# G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
U_g, G_prime, _ = _svd(grad_cov)
# Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on
# the singular values in places where it should be exactly zero. This is to keep
# dimensions with zero grad separate from those with nonzero grad.
G_prime = _diag(torch.matmul(U_g.transpose(2,3), torch.matmul(grad_cov, U_g)))
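# Illustrative sketch (standalone, not part of this commit): plain torch calls stand in
# for the _svd/_diag helpers defined elsewhere in this file. For a covariance whose
# row/column for a zero-grad dimension is exactly zero, svd's singular values are
# typically tiny but nonzero, while diag(U^T G U) stays at (or extremely near) zero:
import torch
_g = torch.randn(8, 4)
_g[:, 2] = 0.0                              # dimension 2 never receives any gradient
_G = _g.t() @ _g                            # 4x4 grad covariance; row/col 2 exactly zero
_U, _S, _ = torch.svd(_G)
_S_recomputed = (_U.t() @ _G @ _U).diag()   # zero-grad dimension stays (near-)exactly zero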
G_prime_noeps = G_prime.clone()
G = grad_cov.clone()
# Use the form of the diagonalized gradient matrix that we get after
# we add the Adam-type smoothing with epsilon.
G_prime += (_mean(G_prime, exclude_dims=[0], keepdim=True) *(denom_rel_eps * denom_rel_eps) +
(eps * eps))
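# Illustrative sketch (standalone): the Adam-type smoothing just above adds a *relative*
# epsilon (scaled by the per-parameter mean eigenvalue) plus an *absolute* epsilon floor;
# the values and the last-dim mean below are made-up stand-ins for the real config and
# the _mean(..., exclude_dims=[0]) helper.
import torch
_eps, _denom_rel_eps = 1.0e-08, 1.0e-05
_eigs = torch.tensor([[4.0, 1.0, 0.0]])                 # (batch_size=1, block_size=3)
_mean_eig = _eigs.mean(dim=-1, keepdim=True)
_eigs = _eigs + _mean_eig * _denom_rel_eps ** 2 + _eps ** 2
# the exactly-zero eigenvalue now has a small positive floor, keeping later inverses sane.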
# P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
# It is not smoothed yet.
P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
G_diag = _diag(G) # aliased with G
G_diag += (_mean(G_diag, exclude_dims=[0], keepdim=True) *(denom_rel_eps * denom_rel_eps) +
(eps * eps))
P_prime_unsmoothed = P_prime
P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)
P_unsmoothed = param_cov
P = param_cov.clone()
P = self._smooth_param_cov(group, p_shape, P, G)
# C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3))
if True:
_,S,_ = _svd(P)
logging.info(f"Eigs of P are {S}")
# C will satisfy: P == torch.matmul(C, C.transpose(2, 3))
# C is of shape (batch_size, num_blocks, block_size, block_size).
#def _fake_cholesky(X):
# U_, S_, _ = _svd(X)
# return U_ * S_.sqrt().unsqueeze(-2)
#C = _fake_cholesky(P_prime)
C = P_prime.cholesky()
#C = _fake_cholesky(P)
C = P.cholesky()
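# Illustrative sketch (standalone): both factorizations above give a C with P == C C^T;
# Cholesky returns a lower-triangular factor, while the commented-out SVD-based
# _fake_cholesky (U * sqrt(S)) returns a non-triangular one.  Toy unbatched example:
import torch
_a = torch.randn(5, 3)
_P = _a.t() @ _a + 0.1 * torch.eye(3)       # symmetric positive-definite 3x3
_C_chol = _P.cholesky()                     # lower-triangular
_U, _S, _ = torch.svd(_P)
_C_svd = _U * _S.sqrt().unsqueeze(-2)       # same trick as the commented-out helper
assert torch.allclose(_C_chol @ _C_chol.t(), _P, atol=1e-4)
assert torch.allclose(_C_svd @ _C_svd.t(), _P, atol=1e-4)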
# OK, P == U C C^T U^T.
# A matrix that takes normally distributed data to P would
# be U_g C, because C I C^T = C C^T = P. We can actually use *any* matrix
# be C, because C I C^T = C C^T = P. We can actually use *any* matrix
# that takes normally distributed data to P, so we can use
# U_g C U for any orthogonal U, since U_g C U I U^T C^T U_g^T == P.
# C U for any orthogonal U, since C U I U^T C^T == P.
# So there is no harm in choosing a matrix U that diagonalizes the
# projected grad_cov. grad_cov gets projected by
# this projects P; its transpose projects the gradient.
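# Illustrative sketch (standalone) of the freedom claimed above: for any orthogonal U,
# C @ U maps unit-covariance noise to the same covariance P, since (C U)(C U)^T = C C^T = P.
import torch
_a = torch.randn(6, 4)
_P = _a.t() @ _a + 0.1 * torch.eye(4)
_C = _P.cholesky()
_U, _, _ = torch.svd(torch.randn(4, 4))     # a random orthogonal matrix
assert torch.allclose((_C @ _U) @ (_C @ _U).t(), _P, atol=1e-4)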
UC = torch.matmul(U_g, C)
grad_cov_proj = torch.matmul(C.transpose(2, 3),
torch.matmul(grad_cov, C))
# instead of projecting grad_cov, we can just use its diagonal, forget the
# U_g part of the transform, and project with C.
grad_cov_proj = torch.matmul(C.transpose(2, 3) * G_prime_noeps.unsqueeze(-1), C)
# OK, grad_cov is diagonalized by U^T C^T U_g^T. So the projection that we
# apply to the param cov is U_g C U
# OK, grad_cov is diagonalized by U^T C^T. So the projection that we
# apply to the param cov is C U
U, S, _ = _svd(grad_cov_proj)
# proj is indexed [batch_idx,block_idx,canonical_coordinate,diagonalized_coordinate],
# so we need to transpose to get Q_{dim}.
proj = torch.matmul(UC, U)
proj = torch.matmul(C, U)
Q[:] = proj.transpose(2, 3)
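# Illustrative sketch (standalone): the two properties the projection computed above is
# designed to have, written for a single unbatched block with toy P (smoothed param_cov)
# and G (grad_cov).
import torch
_ap, _ag = torch.randn(6, 4), torch.randn(6, 4)
_P = _ap.t() @ _ap + 0.1 * torch.eye(4)
_G = _ag.t() @ _ag + 0.1 * torch.eye(4)
_C = _P.cholesky()
_U, _S, _ = torch.svd(_C.t() @ _G @ _C)     # diagonalize grad_cov as projected by C
_proj = _C @ _U
# (1) the projection still factorizes the parameter covariance:
assert torch.allclose(_proj @ _proj.t(), _P, atol=1e-3)
# (2) in the projected basis the gradient covariance is diagonal (equal to _S):
assert torch.allclose(_proj.t() @ _G @ _proj, torch.diag(_S), atol=1e-3)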
@@ -713,20 +709,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _smooth_param_cov(self,
group: dict,
p_shape: torch.Size,
P_prime: Tensor,
G_prime: Tensor) -> Tensor:
P: Tensor,
G: Tensor) -> Tensor:
"""
This function returns a modified/smoothed version of the parameter covariance
P_prime.
P.
Args:
group: dict to look up config values
p_shape: The shape of the parameter we are optimizing
P_prime: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
containing the parameter covariance in a basis that diagonalizes the
gradient covariance.
G_prime: the diagonalized gradient covariance, of shape (batch_size, num_blocks,
block_size)
P: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
containing the parameter covariance
G: the gradient covariance, of shape (batch_size, num_blocks,
block_size, block_size)
state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
@@ -739,7 +734,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
"""
# do smoothing based on 'rank',
# that is intended to compensate for bad estimates of P.
(batch_size, num_blocks, block_size, block_size) = P_prime.shape
(batch_size, num_blocks, block_size, block_size) = P.shape
# `rank_per_block` is the rank of each block of P_prime if we were to estimate it from just one
# parameter tensor. We average it over time, but actually it won't be changing
# too much, so `rank` does tell us something.
@@ -763,48 +758,60 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# we don't need to multiply `smooth` by anything, because at this point, P_prime should have
# diagonal elements close to 1.
P_prime_diag = _diag(P_prime)
P_prime_diag_mean = _mean(P_prime_diag, exclude_dims=[0], keepdim=True)
P_prime_diag += smooth * P_prime_diag_mean
P = P.clone()
P_diag = _diag(P)
P_diag_mean = _mean(P_diag, exclude_dims=[0], keepdim=True)
P_diag += smooth * P_diag_mean
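# Illustrative sketch (standalone): the in-place diagonal smoothing above, with
# Tensor.diagonal() standing in for the _diag helper (which appears to return an
# aliased view of the last two dims) and a made-up `smooth` value.
import torch
_P = torch.eye(3).expand(2, 1, 3, 3).clone()     # (batch_size=2, num_blocks=1, 3, 3)
_smooth = 0.4
_P_diag = _P.diagonal(dim1=-2, dim2=-1)          # a view: writing to it modifies _P
_P_diag += _smooth * _P_diag.mean(dim=(1, 2), keepdim=True)
# every diagonal element of _P is pulled up by `smooth` times its per-batch mean.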
G = G.clone()
G_diag = _diag(G) # aliased
G_diag *= 1.01 # improve its condition, for numerical reasons.
G = self._smooth_cov(G,
group["cov_min"][3],
group["cov_max"][3],
group["cov_pow"][3])
if True:
# This block smooths G_prime.
# Make sure G_prime has unit mean and no eigenvalue is super small. Note: G_prime
# is already diagonalized; the variable G_prime holds just the tensor of eigenvalues.
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
G_prime_min = group["cov_min"][3]
# make sure G_prime has no zero eigs, and is unit mean.
G_prime = G_prime + G_prime_min * G_prime_mean + 1.0e-20
G_prime /= _mean(G_prime, exclude_dims=[0], keepdim=True)
# it now has unit mean.
G_prime_max = group["cov_max"][3]
G_prime = 1. / (1./G_prime + 1./G_prime_max) # apply max
G_prime_pow = group["cov_pow"][3]
if G_prime_pow != 1.0:
G_prime = G_prime ** G_prime_pow
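# Illustrative sketch (standalone): the floor / renormalize / soft-ceiling / power steps
# above, applied to a toy eigenvalue vector with made-up cov_min / cov_max / cov_pow.
import torch
_eigs = torch.tensor([[8.0, 1.0, 0.0]])                 # (batch_size=1, block_size=3)
_cov_min, _cov_max, _cov_pow = 0.025, 4.0, 1.0
_eigs = _eigs + _cov_min * _eigs.mean(dim=-1, keepdim=True) + 1.0e-20   # no zero eigs
_eigs = _eigs / _eigs.mean(dim=-1, keepdim=True)        # renormalize to unit mean
_eigs = 1.0 / (1.0 / _eigs + 1.0 / _cov_max)            # soft ceiling at cov_max
if _cov_pow != 1.0:
    _eigs = _eigs ** _cov_pow                           # optional power flattening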
# C C^T == G.
C = G.cholesky()
# treat the last dim of C as being in an arbitrary space, its next-to-last dim
# is the "canonical" one that we need to sum with the dims of P.
P_gnorm = torch.matmul(C.transpose(2, 3),
torch.matmul(P, C))
G_prime_rms = G_prime.sqrt()
G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
# P_gnorm is a version of P_prime that is multiplied by G (actually
# G^{0.5} P_prime G^{0.5}), so that it reflects the amount of
# loss-function change in each dimension. We also tried smoothing
# a version of P_prime divided by G, but it seemed not to be helpful.
P_gnorm = P_prime * G_prime_scale
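# Illustrative sketch (standalone): when G is diagonal, its Cholesky factor is just
# diag(sqrt(g)), so the matmul form C^T P C and the elementwise form
# P * (sqrt(g) outer sqrt(g)) compute the same gradient-normalized covariance.
import torch
_a = torch.randn(5, 3)
_P = _a.t() @ _a + 0.1 * torch.eye(3)
_g = torch.tensor([4.0, 1.0, 0.25])                     # toy diagonal grad covariance
_C = torch.diag(_g).cholesky()                          # == diag(_g.sqrt())
_gnorm_matmul = _C.t() @ _P @ _C
_gnorm_elementwise = _P * (_g.sqrt().unsqueeze(-1) * _g.sqrt().unsqueeze(-2))
assert torch.allclose(_gnorm_matmul, _gnorm_elementwise, atol=1e-5)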
# Apply another round of smoothing "relative to G"
P_gnorm = self._smooth_cov(P_gnorm,
group["cov_min"][1],
group["cov_max"][1],
group["cov_pow"][1])
# Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
P_prime = P_gnorm / G_prime_scale
# torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)
# In symbols, it solves AX = b, i.e. X = A^{-1} b,
# and assumes A is square upper-triangular (or lower-triangular if upper=False) and does
# not have zeros on the diagonal.
#
# We want to invert the normalization by C to recover P (if there weren't smoothing), so we want
# P = torch.matmul(C_inv.transpose(2, 3),
# torch.matmul(P_gnorm, C_inv))
# the following is equivalent to: P_temp = torch.matmul(C_inv.transpose(2, 3), P_gnorm),
# where C_inv = C.inverse()
P_temp = torch.triangular_solve(P_gnorm, C, upper=False, transpose=True)[0]
# .. now we want to do the same on the other axis, need transpose.
P = torch.triangular_solve(P_temp.transpose(2, 3), C,
upper=False, transpose=True)[0].transpose(2, 3)
if True: # TEMP
C_inv = C.inverse()
P2 = torch.matmul(C_inv.transpose(2, 3),
torch.matmul(P_gnorm, C_inv))
assert (P-P2).norm() <= 0.001 * P.norm()
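# Illustrative sketch (standalone): the two triangular solves above compute
# C^{-T} @ P_gnorm @ C^{-1} without forming C's inverse explicitly.  Unbatched version:
import torch
_a, _b = torch.randn(5, 3), torch.randn(5, 3)
_C = (_a.t() @ _a + torch.eye(3)).cholesky()            # lower-triangular
_P_gnorm = _b.t() @ _b + torch.eye(3)
_tmp = torch.triangular_solve(_P_gnorm, _C, upper=False, transpose=True)[0]  # C^{-T} P_gnorm
_P_rec = torch.triangular_solve(_tmp.t(), _C, upper=False, transpose=True)[0].t()
_C_inv = _C.inverse()
assert torch.allclose(_P_rec, _C_inv.t() @ _P_gnorm @ _C_inv, atol=1e-3)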
# Apply a 3rd round of smoothing in the canonical basis.
P_prime = self._smooth_cov(P_prime,
group["cov_min"][2],
group["cov_max"][2],
group["cov_pow"][2])
return P_prime
P = self._smooth_cov(P,
group["cov_min"][2],
group["cov_max"][2],
group["cov_pow"][2])
return P
def _smooth_cov(self,
X: Tensor,