Fix bug when channel_dim < 0
commit b09a1b2ae6 (parent 9270e32a51)
@@ -325,7 +325,10 @@ class ActivationBalancer(torch.nn.Module):
         channel.
         """
         with torch.no_grad():
-            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
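The old list comprehension compared each nonnegative dim index against a possibly negative self.channel_dim, so a value like -1 never matched and the channel dim was wrongly included in the reduction. A minimal sketch of the failure mode and of the fix; the tensor shape here is a hypothetical example, not taken from the codebase:

import torch

x = torch.randn(4, 10, 8)   # hypothetical (batch, time, channels) tensor
channel_dim = -1

# Old behaviour: -1 never equals 0, 1, or 2, so every dim is summed over
# and the per-channel statistics collapse to a single scalar.
old_sum_dims = [d for d in range(x.ndim) if d != channel_dim]
assert old_sum_dims == [0, 1, 2]                  # channel dim wrongly included
assert torch.mean(x, dim=old_sum_dims).ndim == 0  # scalar, not per-channel

# Fixed behaviour: normalize the negative index before comparing.
if channel_dim < 0:
    channel_dim += x.ndim                         # -1 -> 2
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
assert sum_dims == [0, 1]
assert torch.mean(x, dim=sum_dims).shape == (8,)  # one mean per channel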