diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 2dff91805..1a2324775 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -56,7 +56,8 @@ def get_sum_abs_stats(x: Tensor, dim: int, x = (x > 0).to(dtype=torch.float) orig_numel = x.numel() sum_dims = [ d for d in range(x.ndim) if d != dim ] - x = torch.sum(x, dim=sum_dims) + if len(sum_dims) != 0: + x = torch.sum(x, dim=sum_dims) count = orig_numel // x.numel() x = x.flatten() return x, count