This commit is contained in:
Yifan Yang 2023-06-15 18:01:27 +08:00
parent f44d1b00b1
commit 8936365c5c

View File

@ -2138,7 +2138,7 @@ class ConvolutionModule(nn.Module):
x = self.in_proj(x) # (time, batch, 2*channels)
x, s = x.chunk(2, dim=2)
x, s = x.chunk(2, dim=-1)
s = self.balancer1(s)
s = self.sigmoid(s)
x = self.activation1(x) # identity.
@ -2190,7 +2190,7 @@ class ConvolutionModule(nn.Module):
x = self.in_proj(x) # (time, batch, 2*channels)
x, s = x.chunk(2, dim=-1)
x, s = x.chunk(2, dim=2)
s = self.sigmoid(s)
x = x * s
# (time, batch, channels)