diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 91580f986..2367f7171 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -376,13 +376,11 @@ class ScalarDiagnostic(object): # sum_gradsq is for getting error bars. self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) - print("tick scale:", self.tick_scale) # this will round down. x = (x / self.tick_scale).to(torch.long) x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1) x = x + num_ticks_per_side - print("x indexes: ", x) self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))