Fix diagnostics.py re backoff for eigs

This commit is contained in:
Daniel Povey 2022-12-22 23:00:44 +08:00
parent e5b047a814
commit 2e6610af5e

View File

@ -254,7 +254,7 @@ class TensorDiagnostic(object):
"Error getting eigenvalues, trying another method."
)
eigs, _ = torch.eig(stats)
stats = eigs.abs().sqrt()
stats = eigs.norm(dim=1).sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
if stats_type in [ "rms", "stddev" ]: