Fix bug when channel_dim < 0

Daniel Povey 2022-10-13 13:40:43 +08:00
parent 49c6b6943d
commit 23d6bf7765


@@ -320,7 +320,10 @@ class ActivationBalancer(torch.nn.Module):
         channel.
         """
         with torch.no_grad():
-            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
             # the random.random() thing is to split the difference if x is zero,
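
The root cause is that `range(x.ndim)` only yields non-negative indices, so a negative `self.channel_dim` (e.g. -1) never matches any `d`; `sum_dims` then contains every dimension and the per-channel statistic collapses to a single scalar. Normalizing the index before building `sum_dims` restores the intended per-channel mean. A minimal standalone sketch of the behaviour before and after the fix (illustrative code, not part of the commit):

    import torch

    x = torch.randn(4, 5, 6)   # e.g. (batch, time, channels)
    channel_dim = -1           # channel axis given as a negative index

    # Buggy behaviour: -1 never appears in range(x.ndim), so no dim is excluded.
    buggy_dims = [d for d in range(x.ndim) if d != channel_dim]
    print(torch.mean(x.abs(), dim=buggy_dims).shape)  # torch.Size([]) -- a scalar, not per-channel

    # Fixed behaviour: convert the negative index to a non-negative one first.
    if channel_dim < 0:
        channel_dim += x.ndim  # -1 -> 2
    fixed_dims = [d for d in range(x.ndim) if d != channel_dim]
    print(torch.mean(x.abs(), dim=fixed_dims).shape)  # torch.Size([6]) -- one value per channel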