commit e597b68160 (parent d729e970e1)
repo: https://github.com/k2-fsa/icefall.git

Improve roundoff properties of algorithm for getting C, use randomized testing.
@@ -439,27 +439,34 @@ class NeutralGradient(Optimizer):
#
# You can verify fairly easily that P G P == C: multiply on left and right
# by C^{-0.5} on each side and you can see that both sides equal I.
#
# We can reduce the amount of roundoff by, after decomposing C
# with SVD as (U S U^T), writing C^{0.5} = U Q = Q^T U^T, with Q = S^0.5 U^T.
#
# Then we can write P as follows:
#   P = Q^T U^T (U Q G Q^T U^T)^{-0.5} U Q,
# and it is easy to show that for orthogonal U, (U X U^T)^{-0.5} = U X^{-0.5} U^T,
# so we can use this and cancel U^T U = U U^T = I, giving:
#   P = Q^T (Q G Q^T)^{-0.5} Q.
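# [Illustrative sketch, not part of this commit: a standalone numeric check
#  of the two claims above, using random SPD stand-ins for C and G and a
#  hypothetical helper sym_pow.]
import torch

def rand_spd(n):
    # Random, well-conditioned symmetric positive-definite matrix.
    A = torch.randn(n, n, dtype=torch.double)
    return A @ A.t() + n * torch.eye(n, dtype=torch.double)

def sym_pow(M, p):
    # M^p for symmetric positive-definite M, via eigendecomposition.
    S, U = torch.linalg.eigh(M)
    return (U * S ** p) @ U.t()

torch.manual_seed(0)
n = 4
C, G = rand_spd(n), rand_spd(n)

# Claim 1: P = C^{0.5} (C^{0.5} G C^{0.5})^{-0.5} C^{0.5} satisfies P G P == C.
C_sqrt = sym_pow(C, 0.5)
P = C_sqrt @ sym_pow(C_sqrt @ G @ C_sqrt, -0.5) @ C_sqrt
assert torch.allclose(P @ G @ P, C)

# Claim 2: for orthogonal U, (U X U^T)^{-0.5} == U X^{-0.5} U^T.
X = rand_spd(n)
U, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.double))
assert torch.allclose(sym_pow(U @ X @ U.t(), -0.5),
                      U @ sym_pow(X, -0.5) @ U.t())
# [End of sketch; the commit's code resumes below.]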
# C_sqrt is symmetric.
C_sqrt = torch.matmul(U * S.sqrt(), U.t())  # U . S.sqrt().diag() . U.t()
if True:
    C_check = torch.matmul(C_sqrt, C_sqrt)
    assert torch.allclose(C_check, C, atol=1.0e-03 * C.diag().mean())
# Because C is symmetric, C == U S U^T, we can ignore V.
Q = (U * S.sqrt()).t()
X = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
prod = torch.matmul(C_sqrt, torch.matmul(grad_cov, C_sqrt))
# We need prod^{-0.5}. First make sure prod is perfectly symmetric so we
# can do SVD; we will normalize away any scalar factor, so don't bother
# multiplying by 0.5.
prod = (prod + prod.t())
U2, S2, _ = X.svd()
# Because X^{-0.5} = U2 S2^{-0.5} U2^T, we have
#   P = Q^T U2 S2^{-0.5} U2^T Q
#     = Y Y^T, where
#   Y = Q^T U2 S2^{-0.25}
U, S, V = prod.svd()
# print("S[2] = ", S)
prod_inv_sqrt = torch.matmul(U * (S + 1.0e-30) ** -0.5, U.t())
S2_pow = (S2 + 1.0e-35) ** -0.25
Y = torch.matmul(Q.t(), U2) * S2_pow

P = torch.matmul(Y, Y.t())

P = torch.matmul(torch.matmul(C_sqrt, prod_inv_sqrt), C_sqrt)
P = P * (1.0 / P.diag().mean())  # normalize size of P.

if True:
    if random.random() < 0.1:
        # U, S, V = P.svd()
        # if random.random() < 0.1:
        #     print(f"Min,max eig of P: {S.min().item()},{S.max().item()}")
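To make the hunk above concrete, here is a hedged, self-contained sketch of the two routes to P that appear in it: the direct route through prod = C^{0.5} G C^{0.5}, and the factored route P = Y Y^T with Y = Q^T U2 S2^{-0.25}, which makes P symmetric positive semi-definite by construction. It also checks the U * S.sqrt() broadcasting idiom, which scales the columns of U and so equals right-multiplication by diag(S^{0.5}) without materializing the diagonal matrix. The setup uses random SPD stand-ins for C and grad_cov, and torch.linalg.eigh in place of .svd() (they coincide for symmetric positive-definite inputs); none of this is icefall's actual code.

import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.double)
C = A @ A.t() + n * torch.eye(n, dtype=torch.double)         # stand-in for C
B = torch.randn(n, n, dtype=torch.double)
grad_cov = B @ B.t() + n * torch.eye(n, dtype=torch.double)  # stand-in for G

S, U = torch.linalg.eigh(C)   # C == U diag(S) U^T
# Broadcasting idiom: U * S.sqrt() == U @ torch.diag(S.sqrt()).
assert torch.allclose(U * S.sqrt(), U @ torch.diag(S.sqrt()))
C_sqrt = torch.matmul(U * S.sqrt(), U.t())

# Direct route: P = C^{0.5} prod^{-0.5} C^{0.5}, with prod = C^{0.5} G C^{0.5}.
prod = C_sqrt @ grad_cov @ C_sqrt
S1, U1 = torch.linalg.eigh(prod)
prod_inv_sqrt = torch.matmul(U1 * S1 ** -0.5, U1.t())
P_direct = C_sqrt @ prod_inv_sqrt @ C_sqrt

# Factored route: X = Q G Q^T, Y = Q^T U2 diag(S2^{-0.25}), P = Y Y^T.
Q = (U * S.sqrt()).t()        # Q = S^{0.5} U^T
X = Q @ grad_cov @ Q.t()
S2, U2 = torch.linalg.eigh(X)
Y = torch.matmul(Q.t(), U2) * S2 ** -0.25
P_factored = Y @ Y.t()

# Both routes agree, and P G P recovers C up to roundoff.
assert torch.allclose(P_direct, P_factored)
assert torch.allclose(P_factored @ grad_cov @ P_factored, C)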
@@ -472,7 +479,7 @@ class NeutralGradient(Optimizer):
# C_check should equal C, up to a constant. First normalize both.
C_diff = (C_check / C_check.diag().mean()) - (C / C.diag().mean())
# Actually, roundoff can cause significant differences.
assert C_diff.abs().mean() < 0.1
assert C_diff.abs().mean() < 0.001

return P
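The "randomized testing" in the commit message shows up as the random.random() < 0.1 gate above: the P-G-P consistency check is expensive, so it runs on only about 10% of calls, keeping average overhead low while still catching regressions over the course of a training run. A minimal sketch of that pattern, with hypothetical names rather than icefall's:

import random
import torch

def check_consistency(P, grad_cov, C, tol=0.1):
    # Verify P G P == C up to a scalar factor; the loose tol tolerates roundoff.
    C_check = P @ grad_cov @ P
    C_diff = (C_check / C_check.diag().mean()) - (C / C.diag().mean())
    assert C_diff.abs().mean() < tol

def maybe_check(P, grad_cov, C, prob=0.1):
    # Randomized testing: run the costly invariant check on ~prob of calls.
    if random.random() < prob:
        check_consistency(P, grad_cov, C)

# Example with G = I, where P reduces to C^{0.5} and P G P == C exactly:
torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.double)
C = A @ A.t() + 4 * torch.eye(4, dtype=torch.double)
G = torch.eye(4, dtype=torch.double)
S, U = torch.linalg.eigh(C)
P = (U * S.sqrt()) @ U.t()   # C^{0.5}
for _ in range(100):
    maybe_check(P, G, C)     # the check fires ~10 times on average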