Deal with SVD failure better.

Daniel Povey 2022-07-08 09:00:12 +08:00
parent 97feb8a3ec
commit 923468b8af


@@ -254,7 +254,6 @@ class LearnedGradient(Optimizer):
if step % lr_est_period == 0 and step > 0:
self._update_lrs(group, p, state)
if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
logging.info("Diagonalizing")
self._diagonalize_lrs(group, p, state)
self._zero_exp_avg_sq(state)
self._step(group, p, grad, state)
@@ -443,6 +442,7 @@ class LearnedGradient(Optimizer):
# Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
# torch.linalg.solve(A, B) returns A^{-1} B.
try:
# not all versions of Torch have this.
Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
except:
# torch.solve has args the other way round and also returns LU.
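Not from the commit, just a minimal sketch of the compatibility pattern the comments above describe (the helper name solve_compat is made up): newer torch versions provide torch.linalg.solve(A, B) == A^{-1} B, while older versions only have torch.solve(B, A), which takes its arguments the other way round and also returns the LU factorization.

import torch

def solve_compat(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Returns A^{-1} B on both old and new torch versions.
    try:
        return torch.linalg.solve(A, B)  # not all versions of Torch have this.
    except Exception:
        # Older API: arguments reversed, and it returns (solution, LU).
        X, _ = torch.solve(B, A)
        return X

# e.g. Q_invt_proj_grad = solve_compat(Q.t(), proj_grad)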
@@ -541,84 +541,116 @@ class LearnedGradient(Optimizer):
if size == 1:
continue
Q = state[f"Q_{dim}"]
if True:
# This block limits the singular values of Q.
Q_orig = Q.clone() # TEMP
for i in range(4):
try:
U,S,V = Q.svd()
self._diagonalize_lr(group, p, dim, state)
break # break from the loop over i: nothing was raised, so it succeeded.
except:
# if SVD fails for some reason we can rotate Q by something arbitrary on the
# left, as we're going to rotate here anyway, and try again.
logging.warn("Error doing SVD on Q. Trying again after rotation.")
U,_,_ = torch.randn_like(Q).svd() # Create random orthogonal matrix.
Q[:] = torch.matmul(U, Q)
U,S,V = Q.svd()
S_new = S.clamp(lr_mat_min, lr_mat_max)
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
# Reconstruct Q with the modified S.
Q[:] = torch.matmul(U * S_new, V.t())
if random.random() < 0.03:
subsample = max(1, S.numel() // 20)
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
logging.warning(f"self._diagonalize_lr raised: i={i}, Q.sum()={Q.sum()}: likely "
f"error in SVD. Will try again after random change.")
# Likely torch SVD failed; its GPU implementation has bugs in some torch versions,
# e.g. 1.7.1. Left-multiply Q by an arbitrary orthogonal matrix, and we'll try again.
# (Q is going to be diagonalized on that axis anyway, so this shouldn't meaningfully
# affect the output).
U, S, V = torch.randn_like(Q_orig).svd()
# U is a random orthogonal matrix.
Q[:] = torch.matmul(U, Q_orig)
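The retry logic added above can be summarized in the following standalone sketch (a hypothetical helper, not part of the commit): clone Q, attempt the SVD-based update a few times, and after each failure left-multiply Q by a random orthogonal matrix before retrying; since Q is about to be re-diagonalized on that axis anyway, the extra rotation should not meaningfully change the result. Unlike the commit, this sketch raises if every attempt fails.

import logging
import torch

def process_with_retries(Q: torch.Tensor, process_fn, num_tries: int = 4) -> None:
    # process_fn is assumed to do SVD-based work on Q in place (e.g. call Q.svd()).
    Q_orig = Q.clone()
    for i in range(num_tries):
        try:
            process_fn(Q)
            return  # nothing was raised, so it succeeded.
        except Exception:
            logging.warning(f"process_fn failed on try {i}; retrying after a random rotation.")
            U, _, _ = torch.randn_like(Q_orig).svd()  # U is a random orthogonal matrix.
            Q[:] = torch.matmul(U, Q_orig)
    raise RuntimeError("SVD-based processing of Q failed on all retries.")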
if True:
# This block does the actual diagonalization.
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
# the -1 represents all other tensor dims treated as a batch dimension.
# M_grad is the same shape as M. We could write a pseudo-loss as
# loss = tr(M_grad^T M)
# Because we can decompose M as M == N Q, we can write:
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
# so we can write this as tr(N_grad^T N),
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize), as:
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that is larger if the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = N_grad_cov.svd()
logging.info(f"Diagonalizing, dispersion changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
def _diagonalize_lr(self,
group: dict,
p: Tensor,
dim: int,
state: dict) -> None:
"""
This piece of code was broken out of _diagonalize_lrs so that we can more easily
recover from errors in torch SVD, which can be quite common in some torch versions;
e.g. in torch 1.7.1+cu101, CUDA SVD generates NaNs for some very reasonably-sized
inputs.  One 200 by 200 "bad case", whose min and max elements are -0.2766 and
0.3444 respectively, generates NaNs in its U and V factors.
Having closely-spaced singular values seems to present a problem.
"""
lr_mat_min = group["lr_mat_min"]
lr_mat_max = group["lr_mat_max"]
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
Q = state[f"Q_{dim}"]
if True:
# This block limits the singular values of Q.
U,S,V = Q.svd()
S_new = S.clamp(lr_mat_min, lr_mat_max)
S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
# Reconstruct Q with the modified S.
Q[:] = torch.matmul(U * S_new, V.t())
if random.random() < 0.03:
subsample = max(1, S.numel() // 20)
logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
#
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
# as tr(hat_N_grad hat_N), where:
# hat_N_grad = N_grad U
# hat_N = N U
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.t(), Q)
if True:
# This block does the actual diagonalization.
# Suppose the actual parameter matrix p is M, of shape (-1, size), where
# the -1 represents all other tensor dims treated as a batch dimension.
# M_grad is the same shape as M. We could write a pseudo-loss as
# loss = tr(M_grad^T M)
# Because we can decompose M as M == N Q, we can write:
# loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
# so we can write this as tr(N_grad^T N),
# where N_grad == (Q M_grad^T)^T = M_grad Q^T.
# Now,
# grad_cov == M_grad^T M_grad,
# decaying-averaged over minibatches; this is of shape (size,size).
# Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
# (which is what we want to diagonalize), as:
# N_grad_cov = N_grad^T N_grad
# = Q M_grad^T M_grad Q^T
# = Q grad_cov Q^T
# (note: this makes sense because the 1st index of Q is the diagonalized
# index).
def dispersion(v):
# a ratio that is larger if the eigenvalues are more dispersed.
return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
grad_cov = state[f"grad_cov_{dim}"]
N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
U, S, V = N_grad_cov.svd()
if random.random() < 0.1:
logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
# N_grad_cov is SPD, so
# N_grad_cov = U S U^T.
# Now, we can diagonalize N_grad_cov with:
# U^T N_grad_cov U == S.
# N_grad_cov is a sum of N_grad^T N_grad.
# We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
# The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
# This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
#
# which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
# as tr(hat_N_grad hat_N), where:
# hat_N_grad = N_grad U
# hat_N = N U
# (hat_N means \hat{N}, or N with a hat on it).
# So if we interpret hat_N = N U, the gradient covariance w.r.t.
# hat_N will be diagonalized. We also modify Q to hat_Q when
# we modify hat_N, to keep the product M unchanged:
# M = N Q = N U U^T Q = hat_N hat_Q
# This can be done by setting
# hat_Q := U^T Q (eq.10)
#
# This is the only thing we have to do, as N is implicit
# and not materialized at this point.
Q[:] = torch.matmul(U.t(), Q)
s = Q[:].sum()
if not s == s: # if a NaN was generated:
raise RuntimeError('Likely SVD failure')
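Not part of the commit: a small numerical check of the derivation in the comments above, using made-up shapes. With N_grad_cov = Q grad_cov Q^T = U S U^T, setting hat_Q := U^T Q and hat_N := N U keeps the product M = N Q unchanged while the gradient covariance w.r.t. hat_N becomes (up to rounding) diagonal.

import torch

torch.manual_seed(0)
batch, size = 50, 8
N = torch.randn(batch, size)
Q = torch.randn(size, size)
M_grad = torch.randn(batch, size)

grad_cov = M_grad.t() @ M_grad        # == M_grad^T M_grad, shape (size, size)
N_grad = M_grad @ Q.t()               # grad w.r.t. N, since M = N Q
N_grad_cov = Q @ grad_cov @ Q.t()     # == N_grad^T N_grad == Q grad_cov Q^T
U, S, _ = N_grad_cov.svd()            # N_grad_cov is SPD, so N_grad_cov = U S U^T

hat_Q = U.t() @ Q                     # (eq.10): hat_Q := U^T Q
hat_N = N @ U                         # hat_N := N U
print(torch.allclose(N @ Q, hat_N @ hat_Q, atol=1e-4))   # M = N Q is unchanged
hat_cov = U.t() @ N_grad_cov @ U      # grad covariance w.r.t. hat_N
off_diag = hat_cov - torch.diag(hat_cov.diag())
print((off_diag.abs().max() / S.max()).item() < 1e-4)    # approximately diagonal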
def _get_matrix_var(self,
@@ -1880,8 +1912,7 @@ def _test_eve_cain():
avg_loss = loss.item()
else:
avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
#if n == 0 and epoch % 5 == 0:
if True:
if n == 0 and epoch % 5 == 0:
norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()