diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index d364560f3..45513ef5c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1280,1280,1280,1792,1280,1280",
+        default="1536,1536,2048,1536,1536,1536",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 8c6880595..d2220c787 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -404,8 +404,6 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
 
-        self.small_conv_module = SmallConvolutionModule(embed_dim)
-
         self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel)
 
@@ -471,8 +469,6 @@ class ZipformerEncoderLayer(nn.Module):
         # multi-headed self-attention module
         use_self_attn = (random.random() >= dynamic_skip_rate)
 
-        src = src + self.feed_forward1(src)
-
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
@@ -487,8 +483,7 @@ class ZipformerEncoderLayer(nn.Module):
                 attn_weights[0:1])
 
-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
-            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src = src + self.feed_forward1(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
@@ -1575,80 +1570,6 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)
 
 
-class SmallConvolutionModule(nn.Module):
-    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-
-    """
-
-    def __init__(
-        self, channels: int, hidden_dim: int = 256,
-    ) -> None:
-        super().__init__()
-
-        self.conv1 = nn.Conv1d(
-            channels,
-            hidden_dim,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=True,
-        )
-
-        self.deriv_balancer = ActivationBalancer(
-            hidden_dim, channel_dim=1,
-            min_positive=0.05, max_positive=1.0,
-            max_abs=20.0,
-        )
-
-        self.activation = DoubleSwish()
-
-        self.conv2 = ScaledConv1d(
-            hidden_dim,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=True,
-            initial_scale=0.05,
-        )
-
-    def forward(self,
-                x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains bool in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-
-        """
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        x = self.conv1(x)  # (batch, hidden_dim, time)
-
-        x = self.deriv_balancer(x)
-        x = self.activation(x)
-
-        if src_key_padding_mask is not None:
-            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
-        x = self.conv2(x)
-
-        return x.permute(2, 0, 1)
-
-
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
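For reference, here is a minimal, self-contained sketch (not part of the diff) of the stochastic module-skipping pattern that this patch reorganizes in ZipformerEncoderLayer.forward: a sub-module is bypassed with probability dynamic_skip_rate during eager-mode training, while torch.jit.is_scripting() short-circuits the random draw so a scripted/exported model always runs every branch. SkipDemoLayer, its dimensions, and the 0.2 skip rate are illustrative assumptions, not code from this PR; only the gating idiom is taken from zipformer.py.

import random

import torch
import torch.nn as nn


class SkipDemoLayer(nn.Module):
    # Hypothetical toy layer, only to illustrate the gating pattern above.
    def __init__(self, embed_dim: int, dynamic_skip_rate: float = 0.2):
        super().__init__()
        self.feed_forward = nn.Linear(embed_dim, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.dynamic_skip_rate = dynamic_skip_rate

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch, embed_dim), the layout zipformer.py uses.
        # Applied unconditionally, like feed_forward1 after this change.
        src = src + self.feed_forward(src)
        # Stochastically skipped in eager mode, always run when scripted,
        # like the modules gated on dynamic_skip_rate in the diff.
        if torch.jit.is_scripting() or random.random() >= self.dynamic_skip_rate:
            # Conv1d expects (batch, channels, seq_len), hence the permutes.
            src = src + self.conv(src.permute(1, 2, 0)).permute(2, 0, 1)
        return src


x = torch.randn(100, 8, 64)  # (seq_len, batch, embed_dim)
print(SkipDemoLayer(64)(x).shape)  # torch.Size([100, 8, 64])

Because the random draw happens per forward pass, the skip acts as a layer-level dropout that regularizes training without changing the deterministic behavior of the exported model.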