diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index a99ae4f18..9e2d29ab1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1613,6 +1613,49 @@ class ConvolutionModule(nn.Module): x = x.permute(2, 0, 1) # (time, batch, channel) return x + +class SqueezeExcite1d(nn.Module): + def __init__(self, + channels: int, + bottleneck_channels: int): + super().__init__() + self.to_bottleneck_proj = LinearWithAuxLoss(channels, + bottleneck_channels) + + self.bottleneck_activation = TanSwish() + self.from_bottleneck_proj = nn.Linear(bottleneck_channels, + channels) + + self.balancer = ActivationBalancer( + channels, channel_dim=-1, + min_abs=0.05, + max_abs=ScheduledFloat((0.0, 0.2), + (4000.0, 2.0), + (10000.0, 10.0), + default=1.0), + max_factor=0.02, + min_prob=0.1, + ) + self.activation = nn.Sigmoid() + + + + def forward(self, x: Tensor): + """ + x: a Tensor of shape (batch_size, T, channels). + Returns: something with the same shape as x. + """ + # project before mean, needed for LinearWithAuxLoss (or, at least, better) + bottleneck = self.to_bottleneck_proj(x) + # would replace this mean with cumsum for a causal model. + bottleneck = bottleneck.mean(dim=1, keepdim=True) + bottleneck = self.bottleneck_activation(bottleneck) + scale = self.from_bottleneck_proj(bottleneck) + scale = self.balancer(scale) + scale = self.activation(scale) + return x * scale + + class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/2 length). @@ -1631,6 +1674,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + bottleneck_channels: int = 64, dropout: float = 0.1, ) -> None: """ @@ -1644,6 +1688,8 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite """ assert in_channels >= 7 super().__init__() @@ -1679,6 +1725,10 @@ class Conv2dSubsampling(nn.Module): DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 + + self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels, + bottleneck_channels) + self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) @@ -1698,7 +1748,11 @@ class Conv2dSubsampling(nn.Module): x = self.conv(x) # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, ((T-1)//2 - 1))//2, out_height * layer3_channels)) + x = self.squeeze_excite(x) + x = self.out(x) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) x = self.dropout(x) return x