mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix to diagnostics.py
This commit is contained in:
parent
3faa4ffa3f
commit
183fc7a76d
@ -376,13 +376,11 @@ class ScalarDiagnostic(object):
|
|||||||
# sum_gradsq is for getting error bars.
|
# sum_gradsq is for getting error bars.
|
||||||
self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
|
self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
|
||||||
self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
|
self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
|
||||||
print("tick scale:", self.tick_scale)
|
|
||||||
|
|
||||||
# this will round down.
|
# this will round down.
|
||||||
x = (x / self.tick_scale).to(torch.long)
|
x = (x / self.tick_scale).to(torch.long)
|
||||||
x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
|
x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
|
||||||
x = x + num_ticks_per_side
|
x = x + num_ticks_per_side
|
||||||
print("x indexes: ", x)
|
|
||||||
|
|
||||||
self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
|
self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
|
||||||
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
|
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user