diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 2e6087ad5..609e25626 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -130,6 +130,8 @@ class TensorDiagnostic(object): x = x[0] if not isinstance(x, Tensor): return + if x.numel() == 0: # for empty tensor + return x = x.detach().clone() if x.ndim == 0: x = x.unsqueeze(0)