diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 96c085541..15ec28b3b 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -254,7 +254,7 @@ class TensorDiagnostic(object): "Error getting eigenvalues, trying another method." ) eigs, _ = torch.eig(stats) - stats = eigs.abs().sqrt() + stats = eigs.norm(dim=1).sqrt() # sqrt so it reflects data magnitude, like stddev- not variance if stats_type in [ "rms", "stddev" ]: