mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Fix bug in diagnostics
This commit is contained in:
parent
3d9ddc2016
commit
503f8d521c
@ -56,7 +56,8 @@ def get_sum_abs_stats(x: Tensor, dim: int,
|
|||||||
x = (x > 0).to(dtype=torch.float)
|
x = (x > 0).to(dtype=torch.float)
|
||||||
orig_numel = x.numel()
|
orig_numel = x.numel()
|
||||||
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
sum_dims = [ d for d in range(x.ndim) if d != dim ]
|
||||||
x = torch.sum(x, dim=sum_dims)
|
if len(sum_dims) != 0:
|
||||||
|
x = torch.sum(x, dim=sum_dims)
|
||||||
count = orig_numel // x.numel()
|
count = orig_numel // x.numel()
|
||||||
x = x.flatten()
|
x = x.flatten()
|
||||||
return x, count
|
return x, count
|
||||||
|
Loading…
x
Reference in New Issue
Block a user