diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 0d3b0aa02..d8b184752 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -951,8 +951,9 @@ class Conv2dSubsampling(nn.Module):
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
-                 layer1_channels: int = 64,
-                 layer2_channels: int = 128) -> None:
+                 layer1_channels: int = 8,
+                 layer2_channels: int = 32,
+                 layer3_channels: int = 128) -> None:
         """
         Args:
           in_channels:
@@ -976,7 +977,7 @@ class Conv2dSubsampling(nn.Module):
         self.conv = nn.Sequential(
             ScaledConv2d(
                 in_channels=1, out_channels=layer1_channels,
-                kernel_size=3, stride=2,
+                kernel_size=3, padding=1,
                 initial_speed=initial_speed,
             ),
             ActivationBalancer(channel_dim=1),
@@ -988,8 +989,15 @@ class Conv2dSubsampling(nn.Module):
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels, out_channels=layer3_channels,
+                kernel_size=3, stride=2,
+                initial_speed=initial_speed,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
         )
-        self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
+        self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
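
For context, a minimal sketch (not part of the patch) checking the output geometry of the new three-layer stack: layer1 now uses `padding=1` with no stride, so the overall time subsampling stays at 1/4, produced by the two remaining stride-2 layers. It substitutes plain `nn.Conv2d` for icefall's `ScaledConv2d` and omits `ActivationBalancer`/`DoubleSwish` (neither changes tensor shapes); the channel defaults (8, 32, 128) match the diff.

```python
import torch
import torch.nn as nn

in_channels = 80  # feature dim, e.g. 80-bin fbank
layer1_channels, layer2_channels, layer3_channels = 8, 32, 128

conv = nn.Sequential(
    # layer1: kernel 3, padding=1, stride 1 -> shapes unchanged
    nn.Conv2d(1, layer1_channels, kernel_size=3, padding=1),
    # layer2: kernel 3, stride 2, no padding -> D -> (D - 1) // 2
    nn.Conv2d(layer1_channels, layer2_channels, kernel_size=3, stride=2),
    # layer3: kernel 3, stride 2, no padding -> D -> ((D - 1) // 2 - 1) // 2
    nn.Conv2d(layer2_channels, layer3_channels, kernel_size=3, stride=2),
)

x = torch.randn(4, 1, 100, in_channels)  # (N, 1, T, in_channels)
y = conv(x)
# Time axis: ((100 - 1) // 2 - 1) // 2 = 24; freq axis: ((80 - 1) // 2 - 1) // 2 = 19.
assert y.shape == (4, layer3_channels, 24, ((in_channels - 1) // 2 - 1) // 2)

# The flattened (channels * freq) size is what the updated ScaledLinear expects,
# hence layer3_channels (not layer2_channels) in the new `self.out`:
print(layer3_channels * (((in_channels - 1) // 2 - 1) // 2))  # 128 * 19 = 2432
```

Note that the frequency-axis formula `((in_channels - 1) // 2 - 1) // 2` is unchanged from before the patch: the padded stride-1 layer1 preserves both axes, and the two stride-2 layers reproduce exactly the reduction the old two-layer stack applied, so only the channel multiplier in `self.out` needed updating.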