From b09a1b2ae6afccf5694f089ac7a6aaf7404ce615 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 13:40:43 +0800
Subject: [PATCH] Fix bug when channel_dim < 0

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 7034987d9..2f8a88681 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -325,7 +325,10 @@ class ActivationBalancer(torch.nn.Module):
         channel.
         """
         with torch.no_grad():
-            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
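
Note (not part of the patch): a minimal sketch of why a negative channel_dim broke the original
list comprehension. The input shape and variable names below are illustrative assumptions, not
code from scaling.py; only the two list comprehensions mirror the old and new behaviour.

import torch

# Hypothetical input; suppose the channel dimension is the last one, given as -1.
x = torch.randn(2, 3, 4)
channel_dim = -1

# Old behaviour: -1 never equals 0, 1, or 2, so no dimension is excluded and the
# mean collapses the channel dimension along with everything else.
sum_dims_old = [d for d in range(x.ndim) if d != channel_dim]
print(sum_dims_old)                              # [0, 1, 2]
print(torch.mean(x, dim=sum_dims_old).shape)     # torch.Size([]) -- a single scalar

# Patched behaviour: normalise the negative index first, then exclude it.
if channel_dim < 0:
    channel_dim += x.ndim
sum_dims_new = [d for d in range(x.ndim) if d != channel_dim]
print(sum_dims_new)                              # [0, 1]
print(torch.mean(x, dim=sum_dims_new).shape)     # torch.Size([4]) -- one value per channel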