From ec8804283ca0bed0ac0cbd61de430740c3c3d343 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 14 Jan 2023 14:54:46 +0800 Subject: [PATCH] Try to make SmallConvolutionModule more efficient --- .../pruned_transducer_stateless7/zipformer.py | 35 ++++++------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 29b9d84f1..1595d0544 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1008,18 +1008,13 @@ class SmallConvolutionModule(nn.Module): kernel_size=kernel_size, padding=kernel_size // 2) - self.pointwise_conv1 = nn.Conv1d( - channels, - hidden_dim, - kernel_size=1, - stride=1, - padding=0, - bias=True, - ) + self.linear1 = nn.Linear( + channels, hidden_dim) + # balancer and activation as tuned for ConvolutionModule. self.balancer = Balancer( - hidden_dim, channel_dim=1, + hidden_dim, channel_dim=-1, min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), max_positive=1.0, min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), @@ -1028,15 +1023,8 @@ class SmallConvolutionModule(nn.Module): self.activation = SwooshR() - self.pointwise_conv2 = ScaledConv1d( - hidden_dim, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=True, - initial_scale=0.05, - ) + self.linear2 = ScaledLinear(hidden_dim, channels, + initial_scale=0.05) def forward(self, x: Tensor, @@ -1057,16 +1045,13 @@ class SmallConvolutionModule(nn.Module): if src_key_padding_mask is not None: x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) # (batch, channels, time) - x = self.pointwise_conv1(x) # (batch, hidden_dim, time) - + x = x.permute(2, 0, 1) # (time, batch, channels) + x = self.linear1(x) # (time, batch, hidden_dim) x = self.balancer(x) x = self.activation(x) - x = self.pointwise_conv2(x) - - return x.permute(2, 0, 1) - + x = self.linear2(x) # (time, batch, channels) + return x class CompactRelPositionalEncoding(torch.nn.Module):