diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 9e8a55407..7dede581a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -320,7 +320,10 @@ class ActivationBalancer(torch.nn.Module): channel. """ with torch.no_grad(): - sum_dims = [d for d in range(x.ndim) if d != self.channel_dim] + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) # the random.random() thing is to split the difference if x is zero,