Prevent crash due to error in C_diff

This commit is contained in:
Daniel Povey 2022-06-17 12:35:55 +08:00
parent 6c499dcd66
commit 2fe4af8c99

View File

@ -709,7 +709,9 @@ class NeutralGradient(Optimizer):
C_diff = C_check - C
# Roundoff can cause significant differences, so use a fairly large
# threshold of 0.001. We may increase this later or even remove the check.
assert C_diff.abs().mean() < 0.01 * C.diag().mean()
if not C_diff.abs().mean() < 0.01 * C.diag().mean():
print("Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C.diag().mean()}")
return P