Try to make SmallConvolutionModule more efficient

This commit is contained in:
Daniel Povey 2023-01-14 14:54:46 +08:00
parent 167b58baa0
commit ec8804283c

View File

@@ -1008,18 +1008,13 @@ class SmallConvolutionModule(nn.Module):
kernel_size=kernel_size,
padding=kernel_size // 2)
self.pointwise_conv1 = nn.Conv1d(
channels,
hidden_dim,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.linear1 = nn.Linear(
channels, hidden_dim)
# balancer and activation as tuned for ConvolutionModule.
self.balancer = Balancer(
hidden_dim, channel_dim=1,
hidden_dim, channel_dim=-1,
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_positive=1.0,
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
@@ -1028,15 +1023,8 @@ class SmallConvolutionModule(nn.Module):
self.activation = SwooshR()
self.pointwise_conv2 = ScaledConv1d(
hidden_dim,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
initial_scale=0.05,
)
self.linear2 = ScaledLinear(hidden_dim, channels,
initial_scale=0.05)
def forward(self,
x: Tensor,
@@ -1057,16 +1045,13 @@ class SmallConvolutionModule(nn.Module):
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x) # (batch, channels, time)
x = self.pointwise_conv1(x) # (batch, hidden_dim, time)
x = x.permute(2, 0, 1) # (time, batch, channels)
x = self.linear1(x) # (time, batch, hidden_dim)
x = self.balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
return x.permute(2, 0, 1)
x = self.linear2(x) # (time, batch, channels)
return x
class CompactRelPositionalEncoding(torch.nn.Module):