diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 45513ef5c..d364560f3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1536,1536,2048,1536,1536,1536",
+        default="1280,1280,1280,1792,1280,1280",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 7f9c7d7fa..8c6880595 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -404,6 +404,8 @@ class ZipformerEncoderLayer(nn.Module):

         self.nonlin_attention_module = NonlinAttentionModule(embed_dim)

+        self.small_conv_module = SmallConvolutionModule(embed_dim)
+
         self.conv_module = ConvolutionModule(embed_dim, cnn_module_kernel)

@@ -469,6 +471,8 @@ class ZipformerEncoderLayer(nn.Module):
         # multi-headed self-attention module
         use_self_attn = (random.random() >= dynamic_skip_rate)

+        src = src + self.feed_forward1(src)
+
         if torch.jit.is_scripting() or use_self_attn:
             # attn_weights: (num_heads, batch_size, seq_len, seq_len)
             attn_weights = self.self_attn_weights(
@@ -482,7 +486,9 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.nonlin_attention_module(src, attn_weights[0:1])

-        src = src + self.feed_forward1(src)
+
+        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
@@ -537,7 +543,8 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
+                                                        length_factor=3.0)

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -890,11 +897,14 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+            less weight to small differences of offset near the origin.
     """

     def __init__(
         self, embed_dim: int, dropout_rate: float,
-        max_len: int = 1000
+        max_len: int = 1000,
+        length_factor: float = 1.0,
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
@@ -902,8 +912,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
+        assert length_factor >= 1.0
+        self.length_factor = length_factor
         self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+

     def extend_pe(self, x: Tensor) -> None:
         """Reset the positional encodings."""
         if self.pe is not None:
@@ -933,15 +947,16 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # is important.
         x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length))

-        # length_factor is chosen so that the FFT can exactly separate points
-        # close to the origin (T == 0).  So this part of the formulation is not really
-        # heuristic.
-        length_factor = self.embed_dim / (2.0 * math.pi)  # todo: test this.
+        # if self.length_factor == 1.0, then length_scale is chosen so that the
+        # FFT can exactly separate points close to the origin (T == 0).  So this
+        # part of the formulation is not really heuristic.
+        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)

         # note for machine implementations: if atan is not available, we can use:
         #   x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
         # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2, atan(x))
-        x_atan = (x_compressed / length_factor).atan() # results between -pi and pi
+        x_atan = (x_compressed / length_scale).atan()  # results between -pi/2 and pi/2

         cosines = (x_atan * freqs).cos()
         sines = (x_atan * freqs).sin()
@@ -951,14 +966,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         pe[:, 1::2] = sines
         pe[:, -1] = 1.0  # for bias.

-        # if we have the length_factor correct, the cosines around 0 offset (T in the array)
-        # should be oscillating in sign like -1, 1, -1; and the sines should all be close to
-        # zero.
-        #r = 2
-        #print("cosines = ", cosines[T-r:T+r,-5:])
-        #print("sines = ", sines[T-r:T+r,-5:])
-
-
         self.pe = pe.to(dtype=x.dtype)


@@ -1568,6 +1575,80 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)


+class SmallConvolutionModule(nn.Module):
+    """Part of the Zipformer model: a small version of the Convolution module,
+    using a small (size 3) kernel.
+
+    Args:
+        channels (int): the number of input/output channels.
+        hidden_dim (int): the number of channels between conv1 and conv2.
+    """
+
+    def __init__(
+        self, channels: int, hidden_dim: int = 256,
+    ) -> None:
+        super().__init__()
+
+        self.conv1 = nn.Conv1d(
+            channels,
+            hidden_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+        self.deriv_balancer = ActivationBalancer(
+            hidden_dim, channel_dim=1,
+            min_positive=0.05, max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.conv2 = ScaledConv1d(
+            hidden_dim,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            initial_scale=0.05,
+        )
+
+    def forward(self,
+                x: Tensor,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), contains bool in masked positions.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (batch, channels, time).
+
+        x = self.conv1(x)  # (batch, hidden_dim, time)
+
+        x = self.deriv_balancer(x)
+        x = self.activation(x)
+
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x = self.conv2(x)
+
+        return x.permute(2, 0, 1)
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
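
A note on the wiring change above: feed_forward1 now runs unconditionally before the attention block, and the new small_conv_module sits where feed_forward1 used to be, guarded by the same kind of `random.random() >= dynamic_skip_rate` test that gates self-attention. Below is a minimal runnable sketch of that stochastic-skip pattern; the helper name `maybe_apply` and the toy shapes are illustrative assumptions, not part of the patch.

import random

import torch
import torch.nn as nn

def maybe_apply(module: nn.Module, src: torch.Tensor, dynamic_skip_rate: float) -> torch.Tensor:
    # As in the patch: a scripted/exported model always runs the module, while
    # in eager-mode training the residual branch is dropped with probability
    # dynamic_skip_rate.
    if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
        src = src + module(src)
    return src

src = torch.randn(50, 4, 384)  # (time, batch, channels), toy shape
src = maybe_apply(nn.Identity(), src, dynamic_skip_rate=0.1)
print(src.shape)  # torch.Size([50, 4, 384])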
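The length_factor mechanics can also be reproduced standalone: CompactRelPositionalEncoding first log-compresses the raw offset (unit slope at zero, logarithmic growth for large offsets), then squashes the result through atan, and length_scale = length_factor * embed_dim / (2 * pi) sets how fast that squashing saturates, so a larger length_factor deliberately blurs small offsets near the origin. A self-contained sketch of just those two formulas from the patch (the function name `compressed_atan` and the toy embed_dim are assumptions):

import math

import torch

def compressed_atan(x: torch.Tensor, embed_dim: int, length_factor: float) -> torch.Tensor:
    # Log-compression: d(x_compressed)/dx == 1 at x == 0, log growth for large |x|.
    compression_length = embed_dim ** 0.5
    x_compressed = (
        compression_length
        * x.sign()
        * ((x.abs() + compression_length).log() - math.log(compression_length))
    )
    # atan squashing: results lie in (-pi/2, pi/2); a larger length_factor gives
    # a flatter curve near 0, i.e. less resolution for small offsets.
    length_scale = length_factor * embed_dim / (2.0 * math.pi)
    return (x_compressed / length_scale).atan()

offsets = torch.arange(-5.0, 6.0)  # toy relative offsets
for lf in (1.0, 3.0):  # 3.0 is the value the patch passes from ZipformerEncoder
    print(lf, compressed_atan(offsets, embed_dim=192, length_factor=lf))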
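Finally, a simplified, self-contained stand-in for the new SmallConvolutionModule, useful for checking its shape and masking contract outside icefall. Plain nn.Conv1d and nn.SiLU are substituted here for icefall's ScaledConv1d, DoubleSwish and ActivationBalancer (which live in scaling.py), so this is a sketch of the layout under those assumptions, not the patch's exact numerics.

from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

class SmallConvSketch(nn.Module):
    def __init__(self, channels: int, hidden_dim: int = 256) -> None:
        super().__init__()
        # kernel_size=3 with padding=1 leaves the time dimension unchanged.
        self.conv1 = nn.Conv1d(channels, hidden_dim, kernel_size=3, padding=1)
        self.activation = nn.SiLU()  # stand-in for DoubleSwish
        self.conv2 = nn.Conv1d(hidden_dim, channels, kernel_size=1)  # stand-in for ScaledConv1d

    def forward(self, x: Tensor, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = x.permute(1, 2, 0)  # (time, batch, channels) -> (batch, channels, time)
        x = self.activation(self.conv1(x))
        if src_key_padding_mask is not None:
            # Zero out padded frames before the 1x1 conv, as the patch does.
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1), 0.0)
        x = self.conv2(x)
        return x.permute(2, 0, 1)  # back to (time, batch, channels)

src = torch.randn(50, 4, 384)                 # (time, batch, channels)
mask = torch.zeros(4, 50, dtype=torch.bool)   # (batch, time); True == padded
print(SmallConvSketch(384)(src, mask).shape)  # torch.Size([50, 4, 384])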