mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix bug when channel_dim < 0
This commit is contained in:
parent
9270e32a51
commit
b09a1b2ae6
@ -325,7 +325,10 @@ class ActivationBalancer(torch.nn.Module):
|
||||
channel.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
|
||||
channel_dim = self.channel_dim
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
||||
|
||||
x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
|
||||
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user