Avoid error if svd fails

This commit is contained in:
Daniel Povey 2022-06-24 13:20:20 +08:00
parent 2233c852fd
commit f327407308

View File

@ -913,13 +913,14 @@ class NeutralGradient(Optimizer):
P = torch.matmul(Y, Y.t()) P = torch.matmul(Y, Y.t())
if random.random() < 1.0: #0.025: if random.random() < 1.0: #0.025:
# TEMP:
_,s,_ = P.svd()
print(f"Min,max eig of P: {s.min()},{s.max()}")
# TODO: remove this testing code. # TODO: remove this testing code.
assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric. assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric.
try:
P = 0.5 * (P + P.t())
_,s,_ = P.svd()
print(f"Min,max eig of P: {s.min()},{s.max()}")
except:
pass
# testing... note, this is only true modulo "eps" # testing... note, this is only true modulo "eps"
C_check = torch.matmul(torch.matmul(P, G), P) C_check = torch.matmul(torch.matmul(P, G), P)
C_smoothed = torch.matmul(Q.t(), Q) C_smoothed = torch.matmul(Q.t(), Q)