From e597b6816050c137825575233b4e5aeddd8e81fb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 14 Jun 2022 12:59:41 +0800
Subject: [PATCH] Improve roundoff properties of algorithm for getting C, use
 randomized testing.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 39 +++++++++++--------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 87e8f09ec..fc31a8bf9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -439,27 +439,34 @@ class NeutralGradient(Optimizer):
         #
         # You can verify fairly easily that P G P == C: multiply on left and right
         # by C^{-0.5} on each side and you can see that both sides equal I.
+        #
+        # We can reduce the amount of roundoff by, after decomposing C
+        # with SVD as (U S U^T), writing C^{0.5} = U Q = Q^T U^T, with Q = S^0.5 U^T.
+        #
+        # Then we can write P as follows:
+        #   P = Q^T U^T (U Q G Q^T U^T)^{-0.5} U Q,
+        # and we easily show that for orthogonal U, (U X U^T)^{-0.5} = U X^{-0.5} U^T, so we
+        # can do this and cancel U^T U = U U^T = I, giving:
+        #   P = Q^T (Q G Q^T)^{-0.5} Q.
 
-        # C_sqrt is symmetric.
-        C_sqrt = torch.matmul(U * S.sqrt(), U.t())  # U . S.sqrt().diag() . U.t().
-        if True:
-            C_check = torch.matmul(C_sqrt, C_sqrt)
-            assert(torch.allclose(C_check, C, atol=1.0e-03*C.diag().mean()))
+        # because C is symmetric, C == U S U^T, we can ignore V.
+        Q = (U * S.sqrt()).t()
+        X = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
 
-        prod = torch.matmul(C_sqrt, torch.matmul(grad_cov, C_sqrt))
-        # we need prod^{-0.5}.  First make sure prod is perfectly symmetric so we
-        # can do svd.
-        # we will normalize away any scalar factor so don't bother multiplying by 0.5
-        prod = (prod + prod.t())
+        U2, S2, _ = X.svd()
+        # Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have
+        #   P = Q^T U2 S2^{-0.5} U2^T Q
+        #     = Y Y^T, where
+        #   Y = Q^T U2 S2^{-0.25}
 
-        U, S, V = prod.svd()
-        #print("S[2] = " , S)
-        prod_inv_sqrt = torch.matmul(U * (S + 1.0e-30) ** -0.5, U.t())
+        S2_pow = (S2 + 1.0e-35) ** -0.25
+        Y = torch.matmul(Q.t(), U2) * S2_pow
+
+        P = torch.matmul(Y, Y.t())
 
-        P = torch.matmul(torch.matmul(C_sqrt, prod_inv_sqrt), C_sqrt)
         P = P * (1.0 / P.diag().mean()) # normalize size of P.
 
-        if True:
+        if random.random() < 0.1:
             #U, S, V = P.svd()
             #if random.random() < 0.1:
             #    print(f"Min,max eig of P: {S.min().item()},{S.max().item()}")
@@ -472,7 +479,7 @@ class NeutralGradient(Optimizer):
             # C_check should equal C, up to a constant.  First normalize both
             C_diff = (C_check / C_check.diag().mean()) - (C / C.diag().mean())
             # actually, roundoff can cause significant differences.
-            assert C_diff.abs().mean() < 0.1
+            assert C_diff.abs().mean() < 0.001
 
         return P
 
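
Not part of the patch: a minimal standalone PyTorch sketch of the new computation, to check the identity the added comments derive (P G P == C when P = Q^T (Q G Q^T)^{-0.5} Q with Q = S^0.5 U^T). The matrices C and grad_cov below are random positive-definite stand-ins made up for illustration, not the optimizer's actual state; the variable names otherwise follow the patch.

    import torch

    torch.manual_seed(0)
    n = 20

    # Random symmetric positive-definite stand-ins for the parameter covariance C
    # and the gradient covariance grad_cov (illustration only).
    A = torch.randn(n, n, dtype=torch.float64)
    C = torch.matmul(A, A.t()) + 0.1 * torch.eye(n, dtype=torch.float64)
    B = torch.randn(n, n, dtype=torch.float64)
    grad_cov = torch.matmul(B, B.t()) + 0.1 * torch.eye(n, dtype=torch.float64)

    # P = C^{0.5} (C^{0.5} G C^{0.5})^{-0.5} C^{0.5}, computed as in the patch via
    # Q = S^{0.5} U^T, so that P = Q^T (Q G Q^T)^{-0.5} Q = Y Y^T.
    U, S, _ = C.svd()                        # C == U S U^T since C is symmetric
    Q = (U * S.sqrt()).t()                   # Q = S^{0.5} U^T, so C^{0.5} = U Q
    X = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
    U2, S2, _ = X.svd()
    Y = torch.matmul(Q.t(), U2) * (S2 + 1.0e-35) ** -0.25
    P = torch.matmul(Y, Y.t())               # P = Q^T X^{-0.5} Q

    # The property the comments mention: P grad_cov P should equal C
    # (checked here before the diag-mean normalization the patch applies).
    C_check = torch.matmul(P, torch.matmul(grad_cov, P))
    print((C_check - C).abs().max().item())  # roundoff-level, e.g. ~1e-12

In float64 with well-conditioned test matrices the residual is at roundoff level, which is what motivates the tighter 0.001 threshold in the randomized check at the end of the patch.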