Improve roundoff properties of algorithm for getting C, use randomized testing.

This commit is contained in:
Daniel Povey 2022-06-14 12:59:41 +08:00
parent d729e970e1
commit e597b68160

View File

@ -439,27 +439,34 @@ class NeutralGradient(Optimizer):
#
# You can verify fairly easily that P G P == C: multiply on left and right
# by C^{-0.5} on each side and you can see that both sides equal I.
#
# We can reduce the amount of roundoff by, after decomposing C
# with SVD as (U S U^T), writing C^{0.5} = U Q = Q^T U^T, with Q = S^{0.5} U^T.
#
# Then we can write P as follows:
# P = Q^T U^T (U Q G Q^T U^T)^{-0.5} U Q,
# and we can easily show that for orthogonal U, (U X U^T)^{-0.5} = U X^{-0.5} U^T, so we
# can do this and cancel U^T U = U U^T = I, giving:
# P = Q^T (Q G Q^T)^{-0.5} Q.
# C_sqrt is symmetric.
C_sqrt = torch.matmul(U * S.sqrt(), U.t()) # U . S.sqrt().diag() . U.t().
if True:
C_check = torch.matmul(C_sqrt, C_sqrt)
assert(torch.allclose(C_check, C, atol=1.0e-03*C.diag().mean()))
# because C is symmetric, C == U S U^T, we can ignore V.
Q = (U * S.sqrt()).t()
X = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
prod = torch.matmul(C_sqrt, torch.matmul(grad_cov, C_sqrt))
# we need prod^{-0.5}. First make sure prod is perfectly symmetric so we
# can do svd.
# we will normalize away any scalar factor so don't bother multiplying by 0.5
prod = (prod + prod.t())
U2, S2, _ = X.svd()
# Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have
# P = Q^T U2 S2^{-0.5} U2^T Q
# = Y Y^T, where
# Y = Q^T U2 S2^{-0.25}
U, S, V = prod.svd()
#print("S[2] = " , S)
prod_inv_sqrt = torch.matmul(U * (S + 1.0e-30) ** -0.5, U.t())
S2_pow = (S2 + 1.0e-35) ** -0.25
Y = torch.matmul(Q.t(), U2) * S2_pow
P = torch.matmul(Y, Y.t())
P = torch.matmul(torch.matmul(C_sqrt, prod_inv_sqrt), C_sqrt)
P = P * (1.0 / P.diag().mean()) # normalize size of P.
if True:
if random.random() < 0.1:
#U, S, V = P.svd()
#if random.random() < 0.1:
# print(f"Min,max eig of P: {S.min().item()},{S.max().item()}")
@ -472,7 +479,7 @@ class NeutralGradient(Optimizer):
# C_check should equal C, up to a constant. First normalize both
C_diff = (C_check / C_check.diag().mean()) - (C / C.diag().mean())
# actually, roundoff can cause significant differences.
assert C_diff.abs().mean() < 0.1
assert C_diff.abs().mean() < 0.001
return P