diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 628d31d4b..eb937e0c3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -954,8 +954,9 @@ class Conv2dSubsampling(nn.Module): def __init__(self, in_channels: int, out_channels: int, - layer1_channels: int = 32, - layer2_channels: int = 128) -> None: + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128) -> None: """ Args: in_channels: @@ -973,7 +974,7 @@ class Conv2dSubsampling(nn.Module): self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 + kernel_size=3, padding=1, ), ActivationBalancer(channel_dim=1), DoubleSwish(), @@ -983,8 +984,14 @@ class Conv2dSubsampling(nn.Module): ), ActivationBalancer(channel_dim=1), DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, out_channels=layer3_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), ) - self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) + self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not # needed.