From 8a095c1cd1de2c541f855f47d0498e37ffb35493 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 18 Nov 2022 12:46:40 +0800
Subject: [PATCH] Add SmallConvModule; decrease feedforward dims to keep about
 same num params.

---
 .../ASR/pruned_transducer_stateless7/train.py |  2 +-
 .../pruned_transducer_stateless7/zipformer.py | 80 +++++++++++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 45513ef5c..8d1c65362 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1536,1536,2048,1536,1536,1536",
+        default="1280,1280,1536,1280,1280,1280",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c6d44bbe1..994150b2e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -405,6 +405,8 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)
 
+        self.small_conv_module = SmallConvolutionModule(embed_dim)
+
         self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel)
 
 
@@ -483,6 +485,10 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.nonlin_attention_module(src, attn_weights[0:1])
 
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+            src = src + self.small_conv_module(
+                src, src_key_padding_mask=src_key_padding_mask)
+
         src = src + self.feed_forward1(src)
 
         # pooling module
@@ -1569,6 +1575,80 @@ class ConvolutionModule(nn.Module):
 
         return x.permute(2, 0, 1)
 
 
+class SmallConvolutionModule(nn.Module):
+    """Part of Zipformer model: a small version of the Convolution module
+    that uses a small (size-3) kernel.
+
+    Args:
+        channels (int): The number of input/output channels.
+        hidden_dim (int): The number of channels between the two conv
+            layers (default=256).
+    """
+
+    def __init__(
+        self, channels: int, hidden_dim: int = 256,
+    ) -> None:
+        super().__init__()
+
+        self.conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+        self.deriv_balancer = ActivationBalancer(
+            hidden_dim, channel_dim=1,
+            min_positive=0.05, max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor of shape (#time, batch, channels).
+            src_key_padding_mask: optional mask for the src keys per batch,
+                of shape (batch, #time); True in masked (padding) positions.
+
+        Returns:
+            Tensor: Output tensor of shape (#time, batch, channels).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (batch, channels, time)
+
+        x = self.conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.deriv_balancer(x)
+        x = self.activation(x)
+
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.conv2(x)
+
+        return x.permute(2, 0, 1)
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
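
For reviewers who want to sanity-check the shape contract of the new module, below is a minimal standalone sketch of its data flow. It is not part of the patch: SmallConvSketch and the dimensions are illustrative names only, plain nn.Conv1d and nn.SiLU stand in for icefall's ScaledConv1d and DoubleSwish, and the ActivationBalancer is omitted, since those mainly shape training dynamics rather than tensor shapes.

    import torch
    import torch.nn as nn


    class SmallConvSketch(nn.Module):
        # Hypothetical stand-in for SmallConvolutionModule: a kernel-3 conv up
        # to a bottleneck dim, a nonlinearity, then a 1x1 conv back down.
        def __init__(self, channels: int, hidden_dim: int = 256) -> None:
            super().__init__()
            self.conv1 = nn.Conv1d(channels, hidden_dim, kernel_size=3, padding=1)
            self.activation = nn.SiLU()  # approximates DoubleSwish
            self.conv2 = nn.Conv1d(hidden_dim, channels, kernel_size=1)

        def forward(self, x, src_key_padding_mask=None):
            x = x.permute(1, 2, 0)  # (time, batch, channels) -> (batch, channels, time)
            x = self.activation(self.conv1(x))
            if src_key_padding_mask is not None:
                # zero the padded frames before the 1x1 projection
                x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
            x = self.conv2(x)
            return x.permute(2, 0, 1)  # back to (time, batch, channels)


    T, B, C = 100, 8, 384  # illustrative seq-len, batch size, embedding dim
    m = SmallConvSketch(C)
    src = torch.randn(T, B, C)
    mask = torch.zeros(B, T, dtype=torch.bool)  # True would mark padding frames
    out = m(src, src_key_padding_mask=mask)
    assert out.shape == (T, B, C)  # so the residual `src + out` is shape-compatible

One thing worth flagging in review: the mask is applied after conv1, so a padded frame has already leaked into its two immediate neighbours through the kernel-3 convolution; only the input to conv2 is cleaned.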
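
The call site guards the new module with a stochastic-depth-style gate, reusing the dynamic_skip_rate already in scope in ZipformerEncoderLayer.forward: during eager-mode training the whole batch skips the module with probability dynamic_skip_rate, while a scripted (deployed) model always runs it. A sketch of that pattern, with apply_with_skip as an invented helper name:

    import random
    import torch

    def apply_with_skip(module, src, skip_rate, src_key_padding_mask=None):
        # One coin flip per call, not per sequence, so all sequences in the
        # batch share the skip decision; torch.jit.is_scripting() makes the
        # branch unconditional in a scripted model, so inference stays
        # deterministic.
        if torch.jit.is_scripting() or random.random() >= skip_rate:
            src = src + module(src, src_key_padding_mask=src_key_padding_mask)
        return src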