Fix to diagnostics.py

This commit is contained in:
Daniel Povey 2022-12-03 00:22:30 +08:00
parent 3faa4ffa3f
commit 183fc7a76d

View File

@ -376,13 +376,11 @@ class ScalarDiagnostic(object):
# sum_gradsq is for getting error bars.
self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
print("tick scale:", self.tick_scale)
# this will round down.
x = (x / self.tick_scale).to(torch.long)
x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
x = x + num_ticks_per_side
print("x indexes: ", x)
self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))