diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 4e8deb88f..5a0d22d86 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1704,10 +1704,10 @@ class ConvNeXt(nn.Module): def __init__(self, channels: int, hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), layerdrop_rate: FloatLike = None): super().__init__() - kernel_size = 7 - pad = (kernel_size - 1) // 2 + padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) hidden_channels = channels * hidden_ratio if layerdrop_rate is None: layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) @@ -1717,8 +1717,8 @@ class ConvNeXt(nn.Module): in_channels=channels, out_channels=channels, groups=channels, - kernel_size=7, - padding=(3, 3)) + kernel_size=kernel_size, + padding=padding) self.pointwise_conv1 = nn.Conv2d( in_channels=channels, @@ -1869,9 +1869,9 @@ class Conv2dSubsampling(nn.Module): SwooshR(), ) - self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels), - ConvNeXt(layer3_channels), - ConvNeXt(layer3_channels)) + self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels, kernel_size=(5, 5)), + ConvNeXt(layer3_channels, kernel_size=(5, 5)), + ConvNeXt(layer3_channels, kernel_size=(5, 5))) out_width = (((in_channels - 1) // 2) - 1) // 2