Try to make SmallConvolutionModule more efficient

This commit is contained in:
Daniel Povey 2023-01-14 14:54:46 +08:00
parent 167b58baa0
commit ec8804283c

View File

@@ -1008,18 +1008,13 @@ class SmallConvolutionModule(nn.Module):
kernel_size=kernel_size, kernel_size=kernel_size,
padding=kernel_size // 2) padding=kernel_size // 2)
self.pointwise_conv1 = nn.Conv1d( self.linear1 = nn.Linear(
channels, channels, hidden_dim)
hidden_dim,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
# balancer and activation as tuned for ConvolutionModule. # balancer and activation as tuned for ConvolutionModule.
self.balancer = Balancer( self.balancer = Balancer(
hidden_dim, channel_dim=1, hidden_dim, channel_dim=-1,
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_positive=1.0, max_positive=1.0,
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
@@ -1028,15 +1023,8 @@ class SmallConvolutionModule(nn.Module):
self.activation = SwooshR() self.activation = SwooshR()
self.pointwise_conv2 = ScaledConv1d( self.linear2 = ScaledLinear(hidden_dim, channels,
hidden_dim, initial_scale=0.05)
channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
initial_scale=0.05,
)
def forward(self, def forward(self,
x: Tensor, x: Tensor,
@@ -1057,16 +1045,13 @@ class SmallConvolutionModule(nn.Module):
if src_key_padding_mask is not None: if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x) # (batch, channels, time) x = self.depthwise_conv(x) # (batch, channels, time)
x = self.pointwise_conv1(x) # (batch, hidden_dim, time) x = x.permute(2, 0, 1) # (time, batch, channels)
x = self.linear1(x) # (time, batch, hidden_dim)
x = self.balancer(x) x = self.balancer(x)
x = self.activation(x) x = self.activation(x)
x = self.pointwise_conv2(x) x = self.linear2(x) # (time, batch, channels)
return x
return x.permute(2, 0, 1)
class CompactRelPositionalEncoding(torch.nn.Module): class CompactRelPositionalEncoding(torch.nn.Module):