From 2969eb5467a2ee59899a68618316418626ed32d5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Nov 2022 16:52:21 +0800 Subject: [PATCH] Fix diagnostics bug --- icefall/diagnostics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 036016a46..2a937c6b8 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -420,12 +420,11 @@ class ScalarDiagnostic(object): bin_gradsq = torch.zeros(num_bins) bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq) bin_abs_grad = torch.zeros(num_bins) - bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_gradsq) + bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) avg_grad = (bin_grad / bin_counts) avg_grad_stddev = (bin_gradsq / bin_counts).sqrt() - bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) # boundaries are the "x" values between the bins, e.g. corresponding to the