Update diagnostics (#260)

* update diagnostics.py
This commit is contained in:
Mingshuang Luo 2022-03-30 14:52:55 +08:00 committed by GitHub
parent 395a3f952b
commit f686635b54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
return "" return ""
count = sum(counts) count = sum(counts)
stats = stats / count stats = stats / count
stats, _ = torch.symeig(stats) try:
stats = stats.abs().sqrt() eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs, _ = torch.eigs(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance # sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same: elif sizes_same:
stats = torch.stack(stats).sum(dim=0) stats = torch.stack(stats).sum(dim=0)