diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 697904ca0..d37910b12 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -406,6 +406,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
+            small_conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.2), (16000, 0.1), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = 0.01,
@@ -422,6 +423,11 @@ class ZipformerEncoderLayer(nn.Module):
         # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+        # skip rate for small_conv_module; it is fairly high and remains nonzero
+        # because we don't want this submodule to contribute too much.
+        self.small_conv_skip_rate = copy.deepcopy(small_conv_skip_rate)
+
         # ff2_skip_rate is to prevent the ff2 module from having output that's too big
         # compared to its residual.
         self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
@@ -452,6 +458,7 @@ class ZipformerEncoderLayer(nn.Module):
         self.nonlin_attention = NonlinAttention(embed_dim,
                                                 hidden_channels=embed_dim // 4)
 
+        self.small_conv_module = SmallConvolutionModule(embed_dim)
         self.conv_module = ConvolutionModule(embed_dim,
                                              cnn_module_kernel)
 
@@ -596,6 +603,10 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.balancer_na(self.nonlin_attention(src,
                                                            selected_attn_weights[0:1]))
+
+        if torch.jit.is_scripting() or random.random() >= float(self.small_conv_skip_rate):
+            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+
         src = src + self.feed_forward1(src)
 
         # pooling module
@@ -937,6 +948,92 @@ class SimpleCombiner(torch.nn.Module):
 
 
+class SmallConvolutionModule(nn.Module):
+    """Part of Zipformer model: a small version of the Convolution module that
+    uses a small kernel.  Inspired by ConvNeXt (i.e., the depthwise conv comes first).
+
+    Args:
+        channels (int): The number of input/output channels.
+        hidden_dim (int): The number of channels between the two pointwise
+            conv layers (default=128).
+        kernel_size (int): Kernel size of the depthwise conv layer (default=5).
+    """
+
+    def __init__(self,
+                 channels: int,
+                 hidden_dim: int = 128,
+                 kernel_size: int = 5,
+                 ) -> None:
+        super().__init__()
+
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=channels,
+            out_channels=channels,
+            groups=channels,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2)
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+
+        # balancer and activation as tuned for ConvolutionModule.
+        self.balancer = Balancer(
+            hidden_dim, channel_dim=1,
+            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+            max_positive=1.0,
+            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+            max_abs=10.0,
+        )
+
+        self.activation = SwooshR()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(self,
+                x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), contains True in masked (padding) positions.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (batch, channels, time).
+
+        # zero out padding frames so they cannot leak into neighboring
+        # frames through the depthwise conv.
+        if src_key_padding_mask is not None:
+            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.depthwise_conv(x)   # (batch, channels, time)
+        x = self.pointwise_conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.balancer(x)
+        x = self.activation(x)
+        x = self.pointwise_conv2(x)  # (batch, channels, time)
+
+        return x.permute(2, 0, 1)  # (time, batch, channels)
+
+
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1438,11 +1535,11 @@ class FeedforwardModule(nn.Module):
         self.in_proj = nn.Linear(embed_dim, feedforward_dim)
 
         self.hidden_balancer = Balancer(feedforward_dim,
-                                       channel_dim=-1,
-                                       min_positive=0.3,
-                                       max_positive=1.0,
-                                       min_abs=0.75,
-                                       max_abs=5.0)
+                                        channel_dim=-1,
+                                        min_positive=0.3,
+                                        max_positive=1.0,
+                                        min_abs=0.75,
+                                        max_abs=5.0)
         self.activation = SwooshL()
         self.dropout = Dropout2(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
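
For reference on how the new small_conv_skip_rate schedule behaves: ScheduledFloat (from icefall's scaling.py) interpolates piecewise-linearly between (batch_count, value) breakpoints, clamping at the endpoints. The sketch below is a minimal standalone reimplementation of that interpolation for illustration only; scheduled_value is a hypothetical helper, not icefall's API.

# Minimal standalone sketch of the piecewise-linear schedule semantics
# (illustration only; not icefall's ScheduledFloat class).
from typing import Tuple


def scheduled_value(batch_count: float, *points: Tuple[float, float]) -> float:
    """Value at `batch_count` for breakpoints `points`, sorted by batch count."""
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            # linear interpolation between adjacent breakpoints
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)
    raise ValueError("points must be sorted by batch count")


# The small_conv_skip_rate schedule from the diff: 0.5 at the start of
# training, 0.2 after 4k batches, settling at 0.1 after 16k batches.
for bc in (0.0, 2000.0, 4000.0, 10000.0, 16000.0, 50000.0):
    rate = scheduled_value(bc, (0.0, 0.5), (4000.0, 0.2), (16000.0, 0.1))
    print(f"batch {bc:>7.0f}: skip rate {rate:.3f}")

Note that unlike attention_skip_rate and conv_skip_rate, which decay to 0.0 by 16k batches, this schedule stays at 0.1 for the rest of training, matching the comment that the submodule should keep being skipped occasionally; the torch.jit.is_scripting() test in the forward disables the skip entirely for exported models.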
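A quick shape check for the (time, batch, channels) plumbing in SmallConvolutionModule.forward. Since Balancer, SwooshR and ScaledConv1d live in icefall's scaling.py, this self-contained sketch substitutes plain-torch stand-ins (nn.Identity, nn.ReLU, nn.Conv1d); it verifies only the tensor layout and padding-mask handling, not training behaviour.

# Self-contained shape check; stand-ins replace the icefall-specific pieces.
import torch
import torch.nn as nn


class SmallConvSketch(nn.Module):
    def __init__(self, channels: int, hidden_dim: int = 128, kernel_size: int = 5):
        super().__init__()
        self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size,
                                        groups=channels, padding=kernel_size // 2)
        self.pointwise_conv1 = nn.Conv1d(channels, hidden_dim, 1)
        self.balancer = nn.Identity()  # stand-in for Balancer
        self.activation = nn.ReLU()    # stand-in for SwooshR
        self.pointwise_conv2 = nn.Conv1d(hidden_dim, channels, 1)  # stand-in for ScaledConv1d

    def forward(self, x, src_key_padding_mask=None):
        x = x.permute(1, 2, 0)  # (time, batch, channels) -> (batch, channels, time)
        if src_key_padding_mask is not None:
            # (batch, time) -> (batch, 1, time), broadcast over channels
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
        x = self.pointwise_conv1(self.depthwise_conv(x))
        x = self.pointwise_conv2(self.activation(self.balancer(x)))
        return x.permute(2, 0, 1)  # back to (time, batch, channels)


T, B, C = 50, 4, 384
x = torch.randn(T, B, C)
mask = torch.zeros(B, T, dtype=torch.bool)
mask[:, 40:] = True  # pretend the last 10 frames are padding
y = SmallConvSketch(C)(x, src_key_padding_mask=mask)
assert y.shape == (T, B, C)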