From 923468b8af313dda0dd56e342b54d12da9e59b6d Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 8 Jul 2022 09:00:12 +0800
Subject: [PATCH] Deal with SVD failure better.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 177 ++++++++++--------
 1 file changed, 104 insertions(+), 73 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index cf61b1307..689c2fc25 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -254,7 +254,6 @@ class LearnedGradient(Optimizer):
                 if step % lr_est_period == 0 and step > 0:
                     self._update_lrs(group, p, state)
                 if step % (lr_est_period * diagonalize_period) == 0 and step > 0:
-                    logging.info("Diagonalizing")
                     self._diagonalize_lrs(group, p, state)
                     self._zero_exp_avg_sq(state)
                 self._step(group, p, grad, state)
@@ -443,6 +442,7 @@ class LearnedGradient(Optimizer):
            # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
            # torch.solve(A, B) returns A^{-1} B.
            try:
+                # not all versions of Torch have this.
                Q_invt_proj_grad = torch.linalg.solve(Q.t(), proj_grad)
            except:
                # torch.solve has args the other way round and also returns LU.
@@ -541,84 +541,116 @@ class LearnedGradient(Optimizer):
             if size == 1:
                 continue
             Q = state[f"Q_{dim}"]
-
-            if True:
-                # This block limits the singular values of Q.
+            Q_orig = Q.clone() # TEMP
+            for i in range(4):
                 try:
-                    U,S,V = Q.svd()
+                    self._diagonalize_lr(group, p, dim, state)
+                    break # break from the loop over i: nothing was raised, so it succeeded.
                 except:
-                    # if SVD fails for some reason we can rotate Q by something arbitrary on the
-                    # left, as we're going to rotate here anyway, and try again.
-                    logging.warn("Error doing SVD on Q. Trying again after rotation.")
-                    U,_,_ = torch.randn_like(Q).svd() # Create random orthogonal matrix.
-                    Q[:] = torch.matmul(U, Q)
-                    U,S,V = Q.svd()
-                S_new = S.clamp(lr_mat_min, lr_mat_max)
-                S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
-                S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
-                # Reconstruct Q with the modified S.
-                Q[:] = torch.matmul(U * S_new, V.t())
-                if random.random() < 0.03:
-                    subsample = max(1, S.numel() // 20)
-                    logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
+                    logging.warning(f"self._diagonalize_lr raised: i={i}, Q.sum()={Q.sum()}: likely "
+                                    f"error in SVD. Will try again after a random change.")
+                    # Torch's SVD likely failed; its GPU implementation has bugs in some torch
+                    # versions, e.g. 1.7.1. Left-multiply Q by an arbitrary orthogonal matrix,
+                    # and we'll try again. (Q is going to be diagonalized on that axis anyway,
+                    # so this shouldn't meaningfully affect the output).
+                    U, S, V = torch.randn_like(Q_orig).svd()
+                    # U is a random orthogonal matrix.
+                    Q[:] = torch.matmul(U, Q_orig)

-            if True:
-                # This block does the actual diagonalization.
-                # Suppose the actual parameter matrix p is M, of shape (-1, size), where
-                # the -1 represents all other tensor dims treated as a batch dimension.
-                # M_grad is the same shape as M. We could write a pseudo-loss as
-                #    loss = tr(M_grad^T M)
-                # Because we can decompose M as M == N Q, we can write:
-                #    loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
-                # so we can write this at tr(N_grad^T N),
-                # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
-                # Now,
-                #    grad_cov == M_grad^T M_grad,
-                # decaying-averaged over minibatches; this is of shape (size,size).
-                # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
-                # (which is what we want to diagonalize), as::
-                #    N_grad_cov = N_grad^T N_grad
-                #               = Q M_grad^T M_grad Q^T
-                #               = Q grad_cov Q^T
-                # (note: this makes sense because the 1st index of Q is the diagonalized
-                # index).
-                def dispersion(v):
-                    # a ratio that, if the eigenvalues are more dispersed, it will be larger.
-                    return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
-                grad_cov = state[f"grad_cov_{dim}"]
-                N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
-                N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
-                U, S, V = N_grad_cov.svd()
-                logging.info(f"Diagonalizing, dispersion changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
+    def _diagonalize_lr(self,
+                        group: dict,
+                        p: Tensor,
+                        dim: int,
+                        state: dict) -> None:
+        """
+        This piece of code was broken out of _diagonalize_lrs so that we can more easily
+        recover from errors in torch SVD, which can be quite common in some torch versions;
+        e.g. in torch 1.7.1+cu101, CUDA SVD generates NaNs for some very reasonably-sized
+        inputs; for instance, I have a 200 by 200 "bad case", with min and max elements of
+        -0.2766 and 0.3444 respectively, whose U and V factors come out containing NaNs.
+        Having closely-spaced singular values seems to present a problem.
+        """
+        lr_mat_min = group["lr_mat_min"]
+        lr_mat_max = group["lr_mat_max"]

-                # N_grad_cov is SPD, so
-                #    N_grad_cov = U S U^T.
+        Q = state[f"Q_{dim}"]
+        if True:
+            # This block limits the singular values of Q.
+            U,S,V = Q.svd()
+            S_new = S.clamp(lr_mat_min, lr_mat_max)
+            S_new *= 1.0 / S_new.mean() # normalize so mean is 1.0
+            S_new.clamp_(lr_mat_min, lr_mat_max) # apply limits once more.
+            # Reconstruct Q with the modified S.
+            Q[:] = torch.matmul(U * S_new, V.t())
+            if random.random() < 0.03:
+                subsample = max(1, S.numel() // 20)
+                logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}")

-                # Now, we can diagonalize N_grad_cov with:
-                #    U^T N_grad_cov U == S.
-                # N_grad_cov is a sum of N_grad^T N_grad.
-                # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
-                # The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
-                # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
-                #
-                # which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
-                # as tr(hat_N_grad hat_N), where:
-                #    hat_N_grad = N_grad U
-                #    hat_N = N U
-                # (hat_N means \hat{N}, or N with a hat on it).
-                # So if we interpret hat_N = N U, the gradient covariance w.r.t.
-                # hat_N will be diagonalized. We also modify Q to hat_Q when
-                # we modify hat_N, to keep the product M unchanged:
-                #    M = N Q = N U U^T Q = hat_N hat_Q
-                # This can be done by setting
-                #    hat_Q := U^T Q   (eq.10)
-                #
-                # This is the only thing we have to do, as N is implicit
-                # and not materialized at this point.
-                Q[:] = torch.matmul(U.t(), Q)
+        if True:
+            # This block does the actual diagonalization.
+            # Suppose the actual parameter matrix p is M, of shape (-1, size), where
+            # the -1 represents all other tensor dims treated as a batch dimension.
+            # M_grad is the same shape as M. We could write a pseudo-loss as
+            #    loss = tr(M_grad^T M)
+            # Because we can decompose M as M == N Q, we can write:
+            #    loss = tr(M_grad^T N Q) = tr(Q M_grad^T N),
+            # so we can write this as tr(N_grad^T N),
+            # where N_grad == (Q M_grad^T)^T = M_grad Q^T.
+            # Now,
+            #    grad_cov == M_grad^T M_grad,
+            # decaying-averaged over minibatches; this is of shape (size,size).
+            # Using N_grad = M_grad Q^T, we can write the gradient covariance w.r.t. N
+            # (which is what we want to diagonalize), as::
+            #    N_grad_cov = N_grad^T N_grad
+            #               = Q M_grad^T M_grad Q^T
+            #               = Q grad_cov Q^T
+            # (note: this makes sense because the 1st index of Q is the diagonalized
+            # index).
+
+            def dispersion(v):
+                # a ratio that will be larger if the eigenvalues are more dispersed.
+                return '%.3e' % ((v**2).mean() / (v.mean() ** 2)).item()
+
+            grad_cov = state[f"grad_cov_{dim}"]
+            N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
+            N_grad_cov = N_grad_cov + N_grad_cov.t() # ensure symmetric
+            U, S, V = N_grad_cov.svd()
+            if random.random() < 0.1:
+                logging.info(f"Diagonalizing, shape={tuple(p.shape)}, dim={dim}, dispersion "
+                             f"changed from {dispersion(N_grad_cov.diag())} to {dispersion(S)}")
+
+            # N_grad_cov is SPD, so
+            #    N_grad_cov = U S U^T.
+
+            # Now, we can diagonalize N_grad_cov with:
+            #    U^T N_grad_cov U == S.
+            # N_grad_cov is a sum of N_grad^T N_grad.
+            # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
+            # The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
+            # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
+            #
+            # which we can rearrange as tr(U^T N_grad^T N U). This can be interpreted
+            # as tr(hat_N_grad hat_N), where:
+            #    hat_N_grad = N_grad U
+            #    hat_N = N U
+            # (hat_N means \hat{N}, or N with a hat on it).
+            # So if we interpret hat_N = N U, the gradient covariance w.r.t.
+            # hat_N will be diagonalized. We also modify Q to hat_Q when
+            # we modify hat_N, to keep the product M unchanged:
+            #    M = N Q = N U U^T Q = hat_N hat_Q
+            # This can be done by setting
+            #    hat_Q := U^T Q   (eq.10)
+            #
+            # This is the only thing we have to do, as N is implicit
+            # and not materialized at this point.
+            Q[:] = torch.matmul(U.t(), Q)
+
+            s = Q[:].sum()
+            if not s == s: # if a NaN was generated:
+                raise RuntimeError('Likely SVD failure')


     def _get_matrix_var(
@@ -1880,8 +1912,7 @@ def _test_eve_cain():
             avg_loss = loss.item()
         else:
             avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
-        #if n == 0 and epoch % 5 == 0:
-        if True:
+        if n == 0 and epoch % 5 == 0:
            norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
            norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
            norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
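Note on the hunk at line 541: the recovery strategy the patch adopts is (a) clamp the singular values of Q, and (b) if the SVD raises, or silently produces NaNs, left-multiply Q by a random orthogonal matrix and retry. The sketch below is a minimal standalone illustration of that idea, not code from optim.py; the function name limit_singular_values and the parameters s_min, s_max and max_tries are invented for the example.

import logging

import torch


def limit_singular_values(Q: torch.Tensor,
                          s_min: float = 0.1,
                          s_max: float = 10.0,
                          max_tries: int = 4) -> torch.Tensor:
    """Return a copy of Q with singular values clamped to [s_min, s_max] and
    normalized to mean 1, retrying with a random rotation if the SVD fails."""
    Q_orig = Q.clone()
    for i in range(max_tries):
        try:
            U, S, V = Q.svd()
            S_new = S.clamp(s_min, s_max)
            S_new = S_new / S_new.mean()       # normalize so the mean is 1.0
            S_new = S_new.clamp(s_min, s_max)  # apply the limits once more
            Q_new = torch.matmul(U * S_new, V.t())
            if torch.isnan(Q_new).any():       # buggy SVDs can return NaNs without raising
                raise RuntimeError("SVD produced NaNs")
            return Q_new
        except RuntimeError:
            logging.warning(f"SVD failed on try {i}; retrying after a random rotation")
            # A random orthogonal left-rotation leaves the singular values unchanged;
            # in the patch this is harmless because Q is re-diagonalized afterwards.
            R, _, _ = torch.randn_like(Q_orig).svd()
            Q = torch.matmul(R, Q_orig)
    raise RuntimeError("SVD kept failing after random rotations")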
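The comment block in _diagonalize_lr argues that setting Q := U^T Q, with U taken from the decomposition of N_grad_cov = Q grad_cov Q^T, diagonalizes the gradient covariance w.r.t. N while leaving the product M = N Q unchanged (since N becomes N U). The following small numerical check of that identity is illustrative only and not part of the patch.

import torch

torch.manual_seed(0)
size = 6
Q = torch.randn(size, size)
M_grad = torch.randn(20, size)                 # stand-in for a batch of gradients
grad_cov = torch.matmul(M_grad.t(), M_grad)    # grad_cov = M_grad^T M_grad

N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
N_grad_cov = N_grad_cov + N_grad_cov.t()       # ensure symmetric, as in the patch
U, S, V = N_grad_cov.svd()                     # for an SPD matrix, U holds its eigenvectors

hat_Q = torch.matmul(U.t(), Q)                 # the update the patch applies: Q := U^T Q
hat_N_grad_cov = torch.matmul(hat_Q, torch.matmul(grad_cov, hat_Q.t()))

# Off-diagonal entries should be ~0: the covariance w.r.t. hat_N is now diagonal.
off_diag = hat_N_grad_cov - torch.diag(hat_N_grad_cov.diag())
print(off_diag.abs().max().item())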