From 3ad042444e29a67b145fd00c19326a0603001480 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 29 Jul 2022 14:38:50 +0800
Subject: [PATCH] More changes to reduce numerical roundoff for dims with zero
 grad and params.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 71 ++++++++++++-------
 1 file changed, 46 insertions(+), 25 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 1464abf31..945e3ab19 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -894,6 +894,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         #
         # G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
         U_g, G_prime, _ = _svd(grad_cov)
+        # Recompute G_prime via matrix multiplication, as svd does not seem to produce zeros on
+        # the singular values in places where they should be exactly zero.  This is to keep
+        # dimensions with zero grad separate from those with nonzero grad.
+        G_prime = _diag(torch.matmul(U_g.transpose(2, 3), torch.matmul(grad_cov, U_g)))
+
         # P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
         # It is not smoothed yet.
         P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
@@ -913,6 +918,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # CGC = (C^T G' C), it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
         # if G were formatted as a diagonal matrix, but G is just the diagonal.
         CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
+        # Make sure CGC is exactly symmetric: we want the SVD to be exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        CGC = 0.5 * (CGC + CGC.transpose(2, 3))
 
         U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{0.5}
 
@@ -921,7 +929,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
         X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
         Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
-
+        # Make sure Z_prime_inv is exactly symmetric: we want the SVD to be exact w.r.t.
+        # dimensions with zero grad and zero parameters.
+        Z_prime_inv = 0.5 * (Z_prime_inv + Z_prime_inv.transpose(2, 3))
 
         if True:
             def _check_similar(x, y, name):
@@ -1517,31 +1527,42 @@ def _mean(x: Tensor,
 
 
 
-def _svd(x: Tensor):
+def _svd(x: Tensor, recursion_depth: int = 0):
     # Wrapper for torch svd that catches errors and re-tries (to address bugs in
     # some versions of torch SVD for CUDA).
-    randU = None
-    for i in range(4):
-        try:
-            U, S, V = x.svd()
-            s = U.sum()
-            if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
-                raise RuntimeError(f"svd failed (or random test), sum={s}")
-            if randU is not None:
-                U = torch.matmul(randU.t(), U)
-            return U, S, V  # success
-        except:
-            logging.warning(f"svd failed: i={i}, x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
-                            f"error in SVD.  Will try again after random change.")
-            this_x = x
-            while this_x.ndim > 2:
-                this_x = this_x[0]  # get rid of any batch dims
-            U, S, V = torch.randn_like(this_x).svd()
-            x = torch.matmul(U, x)
-            if randU is None:
-                randU = U
-            else:
-                randU = torch.matmul(U, randU)
+    xsum = x.sum()
+    if not (xsum - xsum == 0):
+        raise RuntimeError("svd on inf or nan input")
+    try:
+        U, S, V = x.svd()
+        s = U.sum()
+        if s - s != 0 or (__name__ == '__main__' and random.random() < 0.01):
+            raise RuntimeError(f"svd failed (or random test), sum={s}")
+        return U, S, V  # success
+    except:
+        if recursion_depth < 2:
+            logging.warning(f"svd failed: recursion_depth={recursion_depth}, "
+                            f"x.shape={tuple(x.shape)}, x.sum()={x.sum()}: likely "
+                            f"error in SVD.  Will try again after a random permutation of rows.")
+            # Randomly permute the row indexes of the matrices and retry, hoping this
+            # fixes the issue.
+            n = x.shape[-2]
+            randperm = torch.randperm(n, device=x.device)
+            inv_randperm = torch.zeros(n, device=x.device, dtype=torch.int64)
+            inv_randperm[randperm] = torch.arange(n, device=x.device)
+            x = torch.index_select(x, dim=-2, index=randperm)
+            U, S, V = _svd(x, recursion_depth + 1)
+            return torch.index_select(U, dim=-2, index=inv_randperm), S, V
+        elif recursion_depth < 4:
+            logging.warning(f"svd failed after {recursion_depth} tries: x.shape={tuple(x.shape)}, x.sum()={x.sum()}.  Will try an orthogonal transformation.")
+            Urand, _, _ = torch.randn(x.shape[-2], x.shape[-1], device=x.device,
+                                      dtype=x.dtype).svd()
+            U, S, V = _svd(torch.matmul(Urand, x),
+                           recursion_depth + 1)
+            return torch.matmul(Urand.t(), U), S, V
+        else:
+            raise RuntimeError(f"svd failed after {recursion_depth} tries")
 
 
@@ -2182,7 +2203,7 @@ def _test_eve_cain():
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
 
-        hidden_dim = 768
+        hidden_dim = 400
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),
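
Note, not part of the patch: below is a minimal standalone sketch of the retry
strategy that the new _svd wrapper implements, together with the symmetrization
trick from the earlier hunks. robust_svd is an illustrative name, the input is
assumed square (as the covariance matrices this optimizer factorizes are), and
Tensor.svd is the legacy PyTorch API the patch itself uses (newer code would
call torch.linalg.svd).

import torch

def robust_svd(x: torch.Tensor, depth: int = 0):
    # Sketch of the patch's _svd: reject inf/nan input up front, then retry a
    # failing SVD under transformations that are undone on the returned U.
    xsum = x.sum()
    if xsum - xsum != 0:  # holds only when the sum is inf or nan
        raise RuntimeError("svd on inf or nan input")
    try:
        return x.svd()
    except RuntimeError:
        if depth >= 4:
            raise
        n = x.shape[-2]
        if depth < 2:
            # First resort: randomly permute the rows, then un-permute U.
            perm = torch.randperm(n, device=x.device)
            inv_perm = torch.empty_like(perm)
            inv_perm[perm] = torch.arange(n, device=x.device)
            U, S, V = robust_svd(x.index_select(-2, perm), depth + 1)
            return U.index_select(-2, inv_perm), S, V
        # Last resort: left-multiply by a random orthogonal matrix, then undo it.
        Urand, _, _ = torch.randn(n, n, device=x.device, dtype=x.dtype).svd()
        U, S, V = robust_svd(torch.matmul(Urand, x), depth + 1)
        return torch.matmul(Urand.t(), U), S, V

if __name__ == "__main__":
    # Averaging a product with its transpose makes it exactly symmetric in
    # floating point, so a dimension whose rows and columns are all zero stays
    # exactly zero -- the property the patch is protecting.
    A = torch.randn(2, 8, 8, dtype=torch.float64)
    A[:, 3, :] = 0.0
    A[:, :, 3] = 0.0
    C = torch.matmul(A, A.transpose(1, 2))
    C = 0.5 * (C + C.transpose(1, 2))
    U, S, V = robust_svd(C)
    print(S[:, -1])  # smallest singular value per batch; 0 for the zeroed dim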