From 5223286424a1da36c241060e4f7c2304b8a2575e Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 2 Jan 2023 14:47:28 +0800
Subject: [PATCH 1/2] Add SmallConvolutionModule

---
 .../pruned_transducer_stateless7/zipformer.py | 110 +++++++++++++++++-
 1 file changed, 104 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index bc4a49d29..dc2eb08ae 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -403,6 +403,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
+            small_conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.3), (4000.0, 0.1), (16000, 0.05), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = 0.01,
@@ -419,6 +420,11 @@ class ZipformerEncoderLayer(nn.Module):
         # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+        # skip rate for small_conv_module; it is fairly high and remains nonzero
+        # because we don't want this submodule to contribute too much.
+        self.small_conv_skip_rate = copy.deepcopy(small_conv_skip_rate)
+
         # ff2_skip_rate is to prevent the ff2 module from having output that's too big
         # compared to its residual.
         self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
@@ -447,9 +453,11 @@ class ZipformerEncoderLayer(nn.Module):
                                                        dropout)
 
         self.nonlin_attention = NonlinAttention(embed_dim,
-                                                hidden_channels=embed_dim // 4)
+                                                hidden_channels=embed_dim // 4)
 
+        self.small_conv_module = SmallConvolutionModule(embed_dim)
+
         self.conv_module = ConvolutionModule(embed_dim,
                                              cnn_module_kernel)
 
@@ -593,6 +601,10 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.balancer_na(self.nonlin_attention(src,
                                                                selected_attn_weights[0:1]))
 
+        if torch.jit.is_scripting() or random.random() >= float(self.small_conv_skip_rate):
+            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+
         src = src + self.feed_forward1(src)
 
         # pooling module
@@ -934,6 +946,92 @@ class SimpleCombiner(torch.nn.Module):
 
 
 
+class SmallConvolutionModule(nn.Module):
+    """Part of the Zipformer model: a small version of the Convolution module that uses a small kernel.
+    Inspired by ConvNeXt (i.e. the depthwise conv comes first).
+
+    Args:
+        channels (int): The number of input/output channels of the conv layers.
+        hidden_dim (int): The number of channels of the pointwise bottleneck (default=128).
+        kernel_size (int): Kernel size of the depthwise conv layer (default=5).
+    """
+
+    def __init__(
+            self, channels: int,
+            hidden_dim: int = 128,
+            kernel_size: int = 5,
+    ) -> None:
+        super().__init__()
+
+
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=channels,
+            out_channels=channels,
+            groups=channels,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2)
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+        # balancer and activation as tuned for ConvolutionModule.
+
+        self.balancer = Balancer(
+            hidden_dim, channel_dim=1,
+            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+            max_positive=1.0,
+            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+            max_abs=10.0,
+        )
+
+        self.activation = SwooshR()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(self,
+                x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), True in positions that correspond to padding.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+        """
+        # rearrange (time, batch, channels) -> (batch, channels, time) for the conv layers
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        if src_key_padding_mask is not None:
+            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.depthwise_conv(x)  # (batch, channels, time)
+        x = self.pointwise_conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.balancer(x)
+        x = self.activation(x)
+        x = self.pointwise_conv2(x)
+
+        return x.permute(2, 0, 1)  # (time, batch, channels)
+
+
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1431,11 +1529,11 @@ class FeedforwardModule(nn.Module):
         self.in_proj = nn.Linear(embed_dim, feedforward_dim)
 
         self.hidden_balancer = Balancer(feedforward_dim,
-                                        channel_dim=-1,
-                                        min_positive=0.3,
-                                        max_positive=1.0,
-                                        min_abs=0.75,
-                                        max_abs=5.0)
+                                        channel_dim=-1,
+                                        min_positive=0.3,
+                                        max_positive=1.0,
+                                        min_abs=0.75,
+                                        max_abs=5.0)
         self.activation = SwooshL()
         self.dropout = Dropout2(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,

From f7d67f5456cca894350e73e18a274c243950b4d5 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 2 Jan 2023 14:58:23 +0800
Subject: [PATCH 2/2] Higher dropout schedule for SmallConvolutionModule

---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index dc2eb08ae..12c60c027 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -403,7 +403,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
-            small_conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.3), (4000.0, 0.1), (16000, 0.05), default=0),
+            small_conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.2), (16000, 0.1), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = 0.01,
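
A minimal smoke test for the new module is sketched below. It is not part of the patch: it assumes it is run from egs/librispeech/ASR/pruned_transducer_stateless7/ so that zipformer.py and its scaling.py helpers (Balancer, SwooshR, ScaledConv1d) import cleanly, and the channel/batch/frame sizes are arbitrary placeholders. It only checks that SmallConvolutionModule preserves the (time, batch, channels) layout that the residual connection in ZipformerEncoderLayer expects.

# Hypothetical shape check (not part of the patch); names and sizes are illustrative.
import torch
from zipformer import SmallConvolutionModule

module = SmallConvolutionModule(channels=384, hidden_dim=128, kernel_size=5)
module.eval()

x = torch.randn(100, 8, 384)                  # (time, batch, channels)
mask = torch.zeros(8, 100, dtype=torch.bool)  # (batch, time); True marks padding
mask[:, 90:] = True                           # pretend the last 10 frames are padding

with torch.no_grad():
    y = module(x, src_key_padding_mask=mask)

# The output keeps the (time, batch, channels) layout, so the encoder layer can
# simply add it to the residual stream.
assert y.shape == x.shape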