Fix bug when channel_dim < 0

Daniel Povey 2022-10-13 13:40:43 +08:00
parent 9270e32a51
commit b09a1b2ae6


@@ -325,7 +325,10 @@ class ActivationBalancer(torch.nn.Module):
             channel.
         """
         with torch.no_grad():
-            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
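
For context, a minimal standalone sketch (not part of the commit; the tensor shape is hypothetical) of why the old line misbehaved when channel_dim was negative: range(x.ndim) only yields non-negative indices, so a channel_dim of -1 was never excluded from sum_dims and the per-channel statistics collapsed to a single scalar.

    import torch

    x = torch.randn(4, 5, 6)   # hypothetical (batch, time, channel) tensor
    channel_dim = -1           # negative index for the channel dimension

    # Old behaviour: d ranges over {0, 1, 2}, none of which equals -1,
    # so the channel dimension is not excluded from the reduction.
    old_sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    print(old_sum_dims)                            # [0, 1, 2]
    print(torch.mean(x, dim=old_sum_dims).shape)   # torch.Size([]) -- a single scalar

    # Fixed behaviour: normalize the negative index first, as in this commit.
    if channel_dim < 0:
        channel_dim += x.ndim
    new_sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    print(new_sum_dims)                            # [0, 1]
    print(torch.mean(x, dim=new_sum_dims).shape)   # torch.Size([6]) -- one value per channel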