Fix diagnostics bug

This commit is contained in:
Daniel Povey 2022-11-30 16:52:21 +08:00
parent b79a794706
commit 2969eb5467

View File

@ -420,12 +420,11 @@ class ScalarDiagnostic(object):
bin_gradsq = torch.zeros(num_bins)
bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
bin_abs_grad = torch.zeros(num_bins)
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
avg_grad = (bin_grad / bin_counts)
avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()
bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
# boundaries are the "x" values between the bins, e.g. corresponding to the