mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Avoid error if svd fails
This commit is contained in:
parent
2233c852fd
commit
f327407308
@ -913,13 +913,14 @@ class NeutralGradient(Optimizer):
|
|||||||
P = torch.matmul(Y, Y.t())
|
P = torch.matmul(Y, Y.t())
|
||||||
|
|
||||||
if random.random() < 1.0: #0.025:
|
if random.random() < 1.0: #0.025:
|
||||||
|
|
||||||
# TEMP:
|
|
||||||
_,s,_ = P.svd()
|
|
||||||
print(f"Min,max eig of P: {s.min()},{s.max()}")
|
|
||||||
|
|
||||||
# TODO: remove this testing code.
|
# TODO: remove this testing code.
|
||||||
assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric.
|
assert (P - P.t()).abs().mean() < 0.01 # make sure symmetric.
|
||||||
|
try:
|
||||||
|
P = 0.5 * (P + P.t())
|
||||||
|
_,s,_ = P.svd()
|
||||||
|
print(f"Min,max eig of P: {s.min()},{s.max()}")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
# testing... note, this is only true modulo "eps"
|
# testing... note, this is only true modulo "eps"
|
||||||
C_check = torch.matmul(torch.matmul(P, G), P)
|
C_check = torch.matmul(torch.matmul(P, G), P)
|
||||||
C_smoothed = torch.matmul(Q.t(), Q)
|
C_smoothed = torch.matmul(Q.t(), Q)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user