mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix diagnostics bug
This commit is contained in:
parent
b79a794706
commit
2969eb5467
@ -420,12 +420,11 @@ class ScalarDiagnostic(object):
|
||||
bin_gradsq = torch.zeros(num_bins)
|
||||
bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
|
||||
bin_abs_grad = torch.zeros(num_bins)
|
||||
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
|
||||
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
|
||||
|
||||
avg_grad = (bin_grad / bin_counts)
|
||||
avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()
|
||||
|
||||
|
||||
bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
|
||||
bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
|
||||
# boundaries are the "x" values between the bins, e.g. corresponding to the
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user