Fix bug in diagnostics

This commit is contained in:
Daniel Povey 2022-03-04 13:08:56 +08:00
parent 3d9ddc2016
commit 503f8d521c

View File

@ -56,6 +56,7 @@ def get_sum_abs_stats(x: Tensor, dim: int,
x = (x > 0).to(dtype=torch.float)
orig_numel = x.numel()
sum_dims = [ d for d in range(x.ndim) if d != dim ]
if len(sum_dims) != 0:
x = torch.sum(x, dim=sum_dims)
count = orig_numel // x.numel()
x = x.flatten()