diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f0a52f605..9f445640c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -735,61 +735,6 @@ class DownsampledZipformerEncoder(nn.Module):
         return self.out_combiner(src_orig, src)
 
 
-class DownsamplingZipformerEncoder(nn.Module):
-    r"""
-    DownsamplingZipformerEncoder is a zipformer encoder that downsamples its input
-    by a specified factor before feeding it to the zipformer layers.
-    """
-    def __init__(self,
-                 encoder: nn.Module,
-                 input_dim: int,
-                 output_dim: int,
-                 downsample: int):
-        super(DownsampledZipformerEncoder, self).__init__()
-        self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
-        self.encoder = encoder
-
-
-    def forward(self,
-                src: Tensor,
-                feature_mask: Union[Tensor, float] = 1.0,
-                mask: Optional[Tensor] = None,
-                src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tensor:
-        r"""Downsample, go through encoder, upsample.
-
-        Args:
-            src: the sequence to the encoder (required).
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer.  feature_mask is expected to be already downsampled by
-                self.downsample_factor.
-            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
-                this, if we are to support it.  Won't work correctly yet.
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            src: (S, N, E).
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
-        Returns: output of shape (S, N, F) where F is the number of output features
-            (output_dim to constructor)
-        """
-        src_orig = src
-        src = self.downsample(src)
-        ds = self.downsample_factor
-        if mask is not None:
-            mask = mask[::ds,::ds]
-        if src_key_padding_mask is not None:
-            src_key_padding_mask = src_key_padding_mask[::ds]
-
-        src = self.encoder(
-            src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
-        )
-        return src
-
 
 class AttentionDownsample(torch.nn.Module):
     """
@@ -1734,6 +1679,71 @@ class ScalarMultiply(nn.Module):
     def forward(self, x):
         return x * self.scale
 
+
+
+class ConvNeXt(nn.Module):
+    """
+    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
+    """
+    def __init__(self,
+                 channels: int,
+                 hidden_ratio: int = 4,
+                 layerdrop_prob: FloatLike = None):
+        super().__init__()
+        kernel_size = 7
+        pad = (kernel_size - 1) // 2
+        hidden_channels = channels * hidden_ratio
+        if layerdrop_prob is None:
+            layerdrop_prob = ScheduledFloat((0.0, 0.1), (16000.0, 0.01))
+        self.layerdrop_prob = layerdrop_prob
+
+        self.depthwise_conv = nn.Conv2d(
+            in_channels=channels,
+            out_channels=channels,
+            groups=channels,
+            kernel_size=7,
+            padding=(3, 3))
+
+        self.pointwise_conv1 = nn.Conv2d(
+            in_channels=channels,
+            out_channels=hidden_channels,
+            kernel_size=1)
+
+        self.hidden_balancer = ActivationBalancer(hidden_channels,
+                                                  channel_dim=1,
+                                                  min_positive=0.3,
+                                                  max_positive=1.0,
+                                                  min_abs=0.75,
+                                                  max_abs=5.0,
+                                                  min_prob=0.25)
+        self.activation = SwooshL()
+        self.pointwise_conv2 = ScaledConv2d(
+            in_channels=hidden_channels,
+            out_channels=channels,
+            kernel_size=1,
+            initial_scale=0.01)
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
+
+        The returned value has the same shape as x.
+        """
+        if (not self.training) or torch.jit.is_scripting() or random.random() < float(self.layerdrop_prob):
+            return x
+
+        bypass = x
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv1(x)
+        x = self.hidden_balancer(x)
+        x = self.activation(x)
+        x = self.pointwise_conv2(x)
+        return bypass + x
+
+
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
 
@@ -1752,7 +1762,6 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
-        bottleneck_channels: int = 64,
         dropout: FloatLike = 0.1,
     ) -> None:
         """
@@ -1778,7 +1787,7 @@ class Conv2dSubsampling(nn.Module):
         # training.  (The second one is necessary to stop its bias from getting
         # a too-large gradient).
 
-        self.conv = nn.Sequential(
+        self.conv1 = nn.Sequential(
             nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
@@ -1797,21 +1806,29 @@ class Conv2dSubsampling(nn.Module):
                 stride=2,
                 padding=0,
             ),
-            ActivationBalancer(layer2_channels,
-                               channel_dim=1,
-                               max_abs=4.0),
-            SwooshR(),
-            nn.Conv2d(
-                in_channels=layer2_channels,
-                out_channels=layer3_channels,
-                kernel_size=3,
-                stride=(1, 2), # (time, freq)
-            ),
-            ActivationBalancer(layer3_channels,
-                               channel_dim=1,
-                               max_abs=4.0),
-            SwooshR(),
         )
+
+        self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
+                                       ConvNeXt(layer2_channels),
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
+
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=(1, 2), # (time, freq)
+            ))
+
+        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
+                                       ConvNeXt(layer3_channels),
+                                       ConvNeXt(layer3_channels),
+                                       BasicNorm(layer3_channels,
+                                                 channel_dim=1))
+
+
         out_height = (((in_channels - 1) // 2) - 1) // 2
         self.scale = nn.Parameter(torch.ones(out_height * layer3_channels))
@@ -1839,7 +1856,12 @@ class Conv2dSubsampling(nn.Module):
         # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
         # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
         # gradients.
-        x = self.conv(x)
+        x = self.conv1(x)
+        x = self.convnext1(x)
+        x = self.conv2(x)
+        x = self.convnext2(x)
+
+        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
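
For reference, here is a minimal standalone sketch (not part of the patch) of the shape behaviour the new ConvNeXt blocks rely on: a depthwise 7x7 convolution with padding 3, a 1x1 expansion, a nonlinearity, a 1x1 projection, and a residual add, all of which preserve the (batch, channels, frames, freqs) layout. Plain nn.GELU stands in for SwooshL, and the ActivationBalancer, ScaledConv2d, and layer-drop details of the patch are omitted; the class name ConvNeXtSketch is made up for illustration.

import torch
import torch.nn as nn


class ConvNeXtSketch(nn.Module):
    """Shape-preserving ConvNeXt-style block: depthwise 7x7 conv,
    1x1 expansion, nonlinearity, 1x1 projection, residual add."""

    def __init__(self, channels: int, hidden_ratio: int = 4):
        super().__init__()
        hidden_channels = channels * hidden_ratio
        # 7x7 depthwise conv with padding 3 keeps the time/freq dims unchanged.
        self.depthwise_conv = nn.Conv2d(
            channels, channels, kernel_size=7, padding=3, groups=channels)
        self.pointwise_conv1 = nn.Conv2d(channels, hidden_channels, kernel_size=1)
        self.activation = nn.GELU()  # stand-in for SwooshL
        self.pointwise_conv2 = nn.Conv2d(hidden_channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, num_channels, num_frames, num_freqs); output has the same shape.
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        return bypass + x


if __name__ == "__main__":
    x = torch.randn(2, 32, 50, 40)   # (N, C, T, F), e.g. C = layer2_channels = 32
    y = ConvNeXtSketch(32)(x)
    assert y.shape == x.shape        # shape-preserving, so such blocks can be
                                     # stacked between the strided convolutions
                                     # of Conv2dSubsampling

Because the block is shape-preserving, stacks of them can be inserted between the strided convolutions of Conv2dSubsampling (conv1 -> convnext1 -> conv2 -> convnext2) without changing the frontend's output geometry.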