Use a measure of correlation for eigs that can be negative.

This commit is contained in:
Daniel Povey 2022-07-26 13:40:57 +08:00
parent b9696878b4
commit e25ca74955

View File

@ -324,7 +324,11 @@ class TensorDiagnostic(object):
if stats.ndim == 2:
# Matrices, for purposes of measuring eigenvalues. Just compute a dot-product-related
# measure of correlation.
# measure of correlation of eigenvalues.
def norm_diag_mean(x):
return x - torch.eye(x.shape[0], dtype=x.dtype, device=x.device) * x.diag().mean()
stats = norm_diag_mean(stats)
other_stats = norm_diag_mean(other_stats)
correlation = ((stats * other_stats).sum() /
((stats**2).sum() * (other_stats**2).sum() + 1.0e-20).sqrt())
else: